mirror of
https://github.com/mudler/LocalAI.git
synced 2025-06-29 14:14:59 +00:00
Merge branch 'master' into fix/stream_tokens_usage
This commit is contained in:
commit
9f6be2be12
79 changed files with 2716 additions and 941 deletions
5
.github/labeler.yml
vendored
5
.github/labeler.yml
vendored
|
@ -1,6 +1,11 @@
|
|||
enhancements:
|
||||
- head-branch: ['^feature', 'feature']
|
||||
|
||||
dependencies:
|
||||
- any:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'Makefile'
|
||||
|
||||
kind/documentation:
|
||||
- any:
|
||||
- changed-files:
|
||||
|
|
17
.github/workflows/bump_deps.yaml
vendored
17
.github/workflows/bump_deps.yaml
vendored
|
@ -12,23 +12,14 @@ jobs:
|
|||
- repository: "ggerganov/llama.cpp"
|
||||
variable: "CPPLLAMA_VERSION"
|
||||
branch: "master"
|
||||
- repository: "go-skynet/go-ggml-transformers.cpp"
|
||||
variable: "GOGGMLTRANSFORMERS_VERSION"
|
||||
branch: "master"
|
||||
- repository: "donomii/go-rwkv.cpp"
|
||||
variable: "RWKV_VERSION"
|
||||
branch: "main"
|
||||
- repository: "ggerganov/whisper.cpp"
|
||||
variable: "WHISPER_CPP_VERSION"
|
||||
branch: "master"
|
||||
- repository: "go-skynet/go-bert.cpp"
|
||||
variable: "BERT_VERSION"
|
||||
branch: "master"
|
||||
- repository: "go-skynet/bloomz.cpp"
|
||||
variable: "BLOOMZ_VERSION"
|
||||
- repository: "PABannier/bark.cpp"
|
||||
variable: "BARKCPP_VERSION"
|
||||
branch: "main"
|
||||
- repository: "mudler/go-ggllm.cpp"
|
||||
variable: "GOGGLLM_VERSION"
|
||||
- repository: "leejet/stable-diffusion.cpp"
|
||||
variable: "STABLEDIFFUSION_GGML_VERSION"
|
||||
branch: "master"
|
||||
- repository: "mudler/go-stable-diffusion"
|
||||
variable: "STABLEDIFFUSION_VERSION"
|
||||
|
|
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -2,6 +2,7 @@
|
|||
/sources/
|
||||
__pycache__/
|
||||
*.a
|
||||
*.o
|
||||
get-sources
|
||||
prepare-sources
|
||||
/backend/cpp/llama/grpc-server
|
||||
|
|
72
Makefile
72
Makefile
|
@ -8,7 +8,7 @@ DETECT_LIBS?=true
|
|||
# llama.cpp versions
|
||||
GOLLAMA_REPO?=https://github.com/go-skynet/go-llama.cpp
|
||||
GOLLAMA_VERSION?=2b57a8ae43e4699d3dc5d1496a1ccd42922993be
|
||||
CPPLLAMA_VERSION?=3ad5451f3b75809e3033e4e577b9f60bcaf6676a
|
||||
CPPLLAMA_VERSION?=08ea539df211e46bb4d0dd275e541cb591d5ebc8
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_REPO?=https://github.com/ggerganov/whisper.cpp
|
||||
|
@ -26,6 +26,14 @@ STABLEDIFFUSION_VERSION?=4a3cd6aeae6f66ee57eae9a0075f8c58c3a6a38f
|
|||
TINYDREAM_REPO?=https://github.com/M0Rf30/go-tiny-dream
|
||||
TINYDREAM_VERSION?=c04fa463ace9d9a6464313aa5f9cd0f953b6c057
|
||||
|
||||
# bark.cpp
|
||||
BARKCPP_REPO?=https://github.com/PABannier/bark.cpp.git
|
||||
BARKCPP_VERSION?=v1.0.0
|
||||
|
||||
# stablediffusion.cpp (ggml)
|
||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||
STABLEDIFFUSION_GGML_VERSION?=9578fdcc4632dc3de5565f28e2fb16b7c18f8d48
|
||||
|
||||
ONNX_VERSION?=1.20.0
|
||||
ONNX_ARCH?=x64
|
||||
ONNX_OS?=linux
|
||||
|
@ -201,6 +209,14 @@ ALL_GRPC_BACKENDS+=backend-assets/grpc/llama-ggml
|
|||
ALL_GRPC_BACKENDS+=backend-assets/grpc/llama-cpp-grpc
|
||||
ALL_GRPC_BACKENDS+=backend-assets/util/llama-cpp-rpc-server
|
||||
ALL_GRPC_BACKENDS+=backend-assets/grpc/whisper
|
||||
|
||||
ifeq ($(ONNX_OS),linux)
|
||||
ifeq ($(ONNX_ARCH),x64)
|
||||
ALL_GRPC_BACKENDS+=backend-assets/grpc/bark-cpp
|
||||
ALL_GRPC_BACKENDS+=backend-assets/grpc/stablediffusion-ggml
|
||||
endif
|
||||
endif
|
||||
|
||||
ALL_GRPC_BACKENDS+=backend-assets/grpc/local-store
|
||||
ALL_GRPC_BACKENDS+=backend-assets/grpc/silero-vad
|
||||
ALL_GRPC_BACKENDS+=$(OPTIONAL_GRPC)
|
||||
|
@ -236,6 +252,23 @@ sources/go-llama.cpp:
|
|||
sources/go-llama.cpp/libbinding.a: sources/go-llama.cpp
|
||||
$(MAKE) -C sources/go-llama.cpp BUILD_TYPE=$(STABLE_BUILD_TYPE) libbinding.a
|
||||
|
||||
## bark.cpp
|
||||
sources/bark.cpp:
|
||||
git clone --recursive $(BARKCPP_REPO) sources/bark.cpp && \
|
||||
cd sources/bark.cpp && \
|
||||
git checkout $(BARKCPP_VERSION) && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
|
||||
sources/bark.cpp/build/libbark.a: sources/bark.cpp
|
||||
cd sources/bark.cpp && \
|
||||
mkdir -p build && \
|
||||
cd build && \
|
||||
cmake $(CMAKE_ARGS) .. && \
|
||||
cmake --build . --config Release
|
||||
|
||||
backend/go/bark/libbark.a: sources/bark.cpp/build/libbark.a
|
||||
$(MAKE) -C backend/go/bark libbark.a
|
||||
|
||||
## go-piper
|
||||
sources/go-piper:
|
||||
mkdir -p sources/go-piper
|
||||
|
@ -249,7 +282,7 @@ sources/go-piper:
|
|||
sources/go-piper/libpiper_binding.a: sources/go-piper
|
||||
$(MAKE) -C sources/go-piper libpiper_binding.a example/main piper.o
|
||||
|
||||
## stable diffusion
|
||||
## stable diffusion (onnx)
|
||||
sources/go-stable-diffusion:
|
||||
mkdir -p sources/go-stable-diffusion
|
||||
cd sources/go-stable-diffusion && \
|
||||
|
@ -262,6 +295,30 @@ sources/go-stable-diffusion:
|
|||
sources/go-stable-diffusion/libstablediffusion.a: sources/go-stable-diffusion
|
||||
CPATH="$(CPATH):/usr/include/opencv4" $(MAKE) -C sources/go-stable-diffusion libstablediffusion.a
|
||||
|
||||
## stablediffusion (ggml)
|
||||
sources/stablediffusion-ggml.cpp:
|
||||
git clone --recursive $(STABLEDIFFUSION_GGML_REPO) sources/stablediffusion-ggml.cpp && \
|
||||
cd sources/stablediffusion-ggml.cpp && \
|
||||
git checkout $(STABLEDIFFUSION_GGML_VERSION) && \
|
||||
git submodule update --init --recursive --depth 1 --single-branch
|
||||
|
||||
sources/stablediffusion-ggml.cpp/build/libstable-diffusion.a: sources/stablediffusion-ggml.cpp
|
||||
cd sources/stablediffusion-ggml.cpp && \
|
||||
mkdir -p build && \
|
||||
cd build && \
|
||||
cmake $(CMAKE_ARGS) .. && \
|
||||
cmake --build . --config Release
|
||||
|
||||
backend/go/image/stablediffusion-ggml/libsd.a: sources/stablediffusion-ggml.cpp/build/libstable-diffusion.a
|
||||
$(MAKE) -C backend/go/image/stablediffusion-ggml libsd.a
|
||||
|
||||
backend-assets/grpc/stablediffusion-ggml: backend/go/image/stablediffusion-ggml/libsd.a backend-assets/grpc
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/backend/go/image/stablediffusion-ggml/ LIBRARY_PATH=$(CURDIR)/backend/go/image/stablediffusion-ggml/ \
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/stablediffusion-ggml ./backend/go/image/stablediffusion-ggml/
|
||||
ifneq ($(UPX),)
|
||||
$(UPX) backend-assets/grpc/stablediffusion-ggml
|
||||
endif
|
||||
|
||||
sources/onnxruntime:
|
||||
mkdir -p sources/onnxruntime
|
||||
curl -L https://github.com/microsoft/onnxruntime/releases/download/v$(ONNX_VERSION)/onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz -o sources/onnxruntime/onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz
|
||||
|
@ -302,7 +359,7 @@ sources/whisper.cpp:
|
|||
sources/whisper.cpp/libwhisper.a: sources/whisper.cpp
|
||||
cd sources/whisper.cpp && $(MAKE) libwhisper.a libggml.a
|
||||
|
||||
get-sources: sources/go-llama.cpp sources/go-piper sources/whisper.cpp sources/go-stable-diffusion sources/go-tiny-dream backend/cpp/llama/llama.cpp
|
||||
get-sources: sources/go-llama.cpp sources/go-piper sources/stablediffusion-ggml.cpp sources/bark.cpp sources/whisper.cpp sources/go-stable-diffusion sources/go-tiny-dream backend/cpp/llama/llama.cpp
|
||||
|
||||
replace:
|
||||
$(GOCMD) mod edit -replace github.com/ggerganov/whisper.cpp=$(CURDIR)/sources/whisper.cpp
|
||||
|
@ -343,7 +400,9 @@ clean: ## Remove build related file
|
|||
rm -rf release/
|
||||
rm -rf backend-assets/*
|
||||
$(MAKE) -C backend/cpp/grpc clean
|
||||
$(MAKE) -C backend/go/bark clean
|
||||
$(MAKE) -C backend/cpp/llama clean
|
||||
$(MAKE) -C backend/go/image/stablediffusion-ggml clean
|
||||
rm -rf backend/cpp/llama-* || true
|
||||
$(MAKE) dropreplace
|
||||
$(MAKE) protogen-clean
|
||||
|
@ -792,6 +851,13 @@ ifneq ($(UPX),)
|
|||
$(UPX) backend-assets/grpc/llama-ggml
|
||||
endif
|
||||
|
||||
backend-assets/grpc/bark-cpp: backend/go/bark/libbark.a backend-assets/grpc
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/backend/go/bark/ LIBRARY_PATH=$(CURDIR)/backend/go/bark/ \
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/bark-cpp ./backend/go/bark/
|
||||
ifneq ($(UPX),)
|
||||
$(UPX) backend-assets/grpc/bark-cpp
|
||||
endif
|
||||
|
||||
backend-assets/grpc/piper: sources/go-piper sources/go-piper/libpiper_binding.a backend-assets/grpc backend-assets/espeak-ng-data
|
||||
CGO_CXXFLAGS="$(PIPER_CGO_CXXFLAGS)" CGO_LDFLAGS="$(PIPER_CGO_LDFLAGS)" LIBRARY_PATH=$(CURDIR)/sources/go-piper \
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/piper ./backend/go/tts/
|
||||
|
|
|
@ -92,6 +92,8 @@ local-ai run oci://localai/phi-2:latest
|
|||
|
||||
## 📰 Latest project news
|
||||
|
||||
- Dec 2024: stablediffusion.cpp backend (ggml) added ( https://github.com/mudler/LocalAI/pull/4289 )
|
||||
- Nov 2024: Bark.cpp backend added ( https://github.com/mudler/LocalAI/pull/4287 )
|
||||
- Nov 2024: Voice activity detection models (**VAD**) added to the API: https://github.com/mudler/LocalAI/pull/4204
|
||||
- Oct 2024: examples moved to [LocalAI-examples](https://github.com/mudler/LocalAI-examples)
|
||||
- Aug 2024: 🆕 FLUX-1, [P2P Explorer](https://explorer.localai.io)
|
||||
|
|
|
@ -240,6 +240,11 @@ message ModelOptions {
|
|||
|
||||
repeated string LoraAdapters = 60;
|
||||
repeated float LoraScales = 61;
|
||||
|
||||
repeated string Options = 62;
|
||||
|
||||
string CacheTypeKey = 63;
|
||||
string CacheTypeValue = 64;
|
||||
}
|
||||
|
||||
message Result {
|
||||
|
|
|
@ -681,7 +681,6 @@ struct llama_server_context
|
|||
slot->sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
|
||||
slot->sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
|
||||
slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
|
||||
slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
|
||||
slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep);
|
||||
slot->sparams.seed = json_value(data, "seed", default_sparams.seed);
|
||||
slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
|
||||
|
@ -1213,13 +1212,12 @@ struct llama_server_context
|
|||
{"mirostat", slot.sparams.mirostat},
|
||||
{"mirostat_tau", slot.sparams.mirostat_tau},
|
||||
{"mirostat_eta", slot.sparams.mirostat_eta},
|
||||
{"penalize_nl", slot.sparams.penalize_nl},
|
||||
{"stop", slot.params.antiprompt},
|
||||
{"n_predict", slot.params.n_predict},
|
||||
{"n_keep", params.n_keep},
|
||||
{"ignore_eos", slot.sparams.ignore_eos},
|
||||
{"stream", slot.params.stream},
|
||||
// {"logit_bias", slot.sparams.logit_bias},
|
||||
// {"logit_bias", slot.sparams.logit_bias},
|
||||
{"n_probs", slot.sparams.n_probs},
|
||||
{"min_keep", slot.sparams.min_keep},
|
||||
{"grammar", slot.sparams.grammar},
|
||||
|
@ -2112,7 +2110,6 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, llama
|
|||
// slot->sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
|
||||
// slot->sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
|
||||
// slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
|
||||
// slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
|
||||
// slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep);
|
||||
// slot->params.seed = json_value(data, "seed", default_params.seed);
|
||||
// slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
|
||||
|
@ -2135,7 +2132,6 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, llama
|
|||
data["mirostat"] = predict->mirostat();
|
||||
data["mirostat_tau"] = predict->mirostattau();
|
||||
data["mirostat_eta"] = predict->mirostateta();
|
||||
data["penalize_nl"] = predict->penalizenl();
|
||||
data["n_keep"] = predict->nkeep();
|
||||
data["seed"] = predict->seed();
|
||||
data["grammar"] = predict->grammar();
|
||||
|
@ -2181,7 +2177,6 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, llama
|
|||
// llama.params.sparams.mirostat = predict->mirostat();
|
||||
// llama.params.sparams.mirostat_tau = predict->mirostattau();
|
||||
// llama.params.sparams.mirostat_eta = predict->mirostateta();
|
||||
// llama.params.sparams.penalize_nl = predict->penalizenl();
|
||||
// llama.params.n_keep = predict->nkeep();
|
||||
// llama.params.seed = predict->seed();
|
||||
// llama.params.sparams.grammar = predict->grammar();
|
||||
|
@ -2228,6 +2223,35 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, llama
|
|||
// }
|
||||
// }
|
||||
|
||||
const std::vector<ggml_type> kv_cache_types = {
|
||||
GGML_TYPE_F32,
|
||||
GGML_TYPE_F16,
|
||||
GGML_TYPE_BF16,
|
||||
GGML_TYPE_Q8_0,
|
||||
GGML_TYPE_Q4_0,
|
||||
GGML_TYPE_Q4_1,
|
||||
GGML_TYPE_IQ4_NL,
|
||||
GGML_TYPE_Q5_0,
|
||||
GGML_TYPE_Q5_1,
|
||||
};
|
||||
|
||||
static ggml_type kv_cache_type_from_str(const std::string & s) {
|
||||
for (const auto & type : kv_cache_types) {
|
||||
if (ggml_type_name(type) == s) {
|
||||
return type;
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("Unsupported cache type: " + s);
|
||||
}
|
||||
|
||||
static std::string get_all_kv_cache_types() {
|
||||
std::ostringstream msg;
|
||||
for (const auto & type : kv_cache_types) {
|
||||
msg << ggml_type_name(type) << (&type == &kv_cache_types.back() ? "" : ", ");
|
||||
}
|
||||
return msg.str();
|
||||
}
|
||||
|
||||
static void params_parse(const backend::ModelOptions* request,
|
||||
common_params & params) {
|
||||
|
||||
|
@ -2241,6 +2265,12 @@ static void params_parse(const backend::ModelOptions* request,
|
|||
}
|
||||
// params.model_alias ??
|
||||
params.model_alias = request->modelfile();
|
||||
if (!request->cachetypekey().empty()) {
|
||||
params.cache_type_k = kv_cache_type_from_str(request->cachetypekey());
|
||||
}
|
||||
if (!request->cachetypevalue().empty()) {
|
||||
params.cache_type_v = kv_cache_type_from_str(request->cachetypevalue());
|
||||
}
|
||||
params.n_ctx = request->contextsize();
|
||||
//params.memory_f16 = request->f16memory();
|
||||
params.cpuparams.n_threads = request->threads();
|
||||
|
|
25
backend/go/bark/Makefile
Normal file
25
backend/go/bark/Makefile
Normal file
|
@ -0,0 +1,25 @@
|
|||
INCLUDE_PATH := $(abspath ./)
|
||||
LIBRARY_PATH := $(abspath ./)
|
||||
|
||||
AR?=ar
|
||||
|
||||
BUILD_TYPE?=
|
||||
# keep standard at C11 and C++11
|
||||
CXXFLAGS = -I. -I$(INCLUDE_PATH)/../../../sources/bark.cpp/examples -I$(INCLUDE_PATH)/../../../sources/bark.cpp/spm-headers -I$(INCLUDE_PATH)/../../../sources/bark.cpp -O3 -DNDEBUG -std=c++17 -fPIC
|
||||
LDFLAGS = -L$(LIBRARY_PATH) -L$(LIBRARY_PATH)/../../../sources/bark.cpp/build/examples -lbark -lstdc++ -lm
|
||||
|
||||
# warnings
|
||||
CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function
|
||||
|
||||
gobark.o:
|
||||
$(CXX) $(CXXFLAGS) gobark.cpp -o gobark.o -c $(LDFLAGS)
|
||||
|
||||
libbark.a: gobark.o
|
||||
cp $(INCLUDE_PATH)/../../../sources/bark.cpp/build/libbark.a ./
|
||||
$(AR) rcs libbark.a gobark.o
|
||||
$(AR) rcs libbark.a $(LIBRARY_PATH)/../../../sources/bark.cpp/build/encodec.cpp/ggml/src/CMakeFiles/ggml.dir/ggml.c.o
|
||||
$(AR) rcs libbark.a $(LIBRARY_PATH)/../../../sources/bark.cpp/build/encodec.cpp/ggml/src/CMakeFiles/ggml.dir/ggml-alloc.c.o
|
||||
$(AR) rcs libbark.a $(LIBRARY_PATH)/../../../sources/bark.cpp/build/encodec.cpp/ggml/src/CMakeFiles/ggml.dir/ggml-backend.c.o
|
||||
|
||||
clean:
|
||||
rm -f gobark.o libbark.a
|
85
backend/go/bark/gobark.cpp
Normal file
85
backend/go/bark/gobark.cpp
Normal file
|
@ -0,0 +1,85 @@
|
|||
#include <iostream>
|
||||
#include <tuple>
|
||||
|
||||
#include "bark.h"
|
||||
#include "gobark.h"
|
||||
#include "common.h"
|
||||
#include "ggml.h"
|
||||
|
||||
struct bark_context *c;
|
||||
|
||||
void bark_print_progress_callback(struct bark_context *bctx, enum bark_encoding_step step, int progress, void *user_data) {
|
||||
if (step == bark_encoding_step::SEMANTIC) {
|
||||
printf("\rGenerating semantic tokens... %d%%", progress);
|
||||
} else if (step == bark_encoding_step::COARSE) {
|
||||
printf("\rGenerating coarse tokens... %d%%", progress);
|
||||
} else if (step == bark_encoding_step::FINE) {
|
||||
printf("\rGenerating fine tokens... %d%%", progress);
|
||||
}
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
int load_model(char *model) {
|
||||
// initialize bark context
|
||||
struct bark_context_params ctx_params = bark_context_default_params();
|
||||
bark_params params;
|
||||
|
||||
params.model_path = model;
|
||||
|
||||
// ctx_params.verbosity = verbosity;
|
||||
ctx_params.progress_callback = bark_print_progress_callback;
|
||||
ctx_params.progress_callback_user_data = nullptr;
|
||||
|
||||
struct bark_context *bctx = bark_load_model(params.model_path.c_str(), ctx_params, params.seed);
|
||||
if (!bctx) {
|
||||
fprintf(stderr, "%s: Could not load model\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
c = bctx;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int tts(char *text,int threads, char *dst ) {
|
||||
|
||||
ggml_time_init();
|
||||
const int64_t t_main_start_us = ggml_time_us();
|
||||
|
||||
// generate audio
|
||||
if (!bark_generate_audio(c, text, threads)) {
|
||||
fprintf(stderr, "%s: An error occured. If the problem persists, feel free to open an issue to report it.\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
const float *audio_data = bark_get_audio_data(c);
|
||||
if (audio_data == NULL) {
|
||||
fprintf(stderr, "%s: Could not get audio data\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
const int audio_arr_size = bark_get_audio_data_size(c);
|
||||
|
||||
std::vector<float> audio_arr(audio_data, audio_data + audio_arr_size);
|
||||
|
||||
write_wav_on_disk(audio_arr, dst);
|
||||
|
||||
// report timing
|
||||
{
|
||||
const int64_t t_main_end_us = ggml_time_us();
|
||||
const int64_t t_load_us = bark_get_load_time(c);
|
||||
const int64_t t_eval_us = bark_get_eval_time(c);
|
||||
|
||||
printf("\n\n");
|
||||
printf("%s: load time = %8.2f ms\n", __func__, t_load_us / 1000.0f);
|
||||
printf("%s: eval time = %8.2f ms\n", __func__, t_eval_us / 1000.0f);
|
||||
printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int unload() {
|
||||
bark_free(c);
|
||||
}
|
||||
|
52
backend/go/bark/gobark.go
Normal file
52
backend/go/bark/gobark.go
Normal file
|
@ -0,0 +1,52 @@
|
|||
package main
|
||||
|
||||
// #cgo CXXFLAGS: -I${SRCDIR}/../../../sources/bark.cpp/ -I${SRCDIR}/../../../sources/bark.cpp/encodec.cpp -I${SRCDIR}/../../../sources/bark.cpp/examples -I${SRCDIR}/../../../sources/bark.cpp/spm-headers
|
||||
// #cgo LDFLAGS: -L${SRCDIR}/ -L${SRCDIR}/../../../sources/bark.cpp/build/examples -L${SRCDIR}/../../../sources/bark.cpp/build/encodec.cpp/ -lbark -lencodec -lcommon
|
||||
// #include <gobark.h>
|
||||
// #include <stdlib.h>
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
type Bark struct {
|
||||
base.SingleThread
|
||||
threads int
|
||||
}
|
||||
|
||||
func (sd *Bark) Load(opts *pb.ModelOptions) error {
|
||||
|
||||
sd.threads = int(opts.Threads)
|
||||
|
||||
modelFile := C.CString(opts.ModelFile)
|
||||
defer C.free(unsafe.Pointer(modelFile))
|
||||
|
||||
ret := C.load_model(modelFile)
|
||||
if ret != 0 {
|
||||
return fmt.Errorf("inference failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sd *Bark) TTS(opts *pb.TTSRequest) error {
|
||||
t := C.CString(opts.Text)
|
||||
defer C.free(unsafe.Pointer(t))
|
||||
|
||||
dst := C.CString(opts.Dst)
|
||||
defer C.free(unsafe.Pointer(dst))
|
||||
|
||||
threads := C.int(sd.threads)
|
||||
|
||||
ret := C.tts(t, threads, dst)
|
||||
if ret != 0 {
|
||||
return fmt.Errorf("inference failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
8
backend/go/bark/gobark.h
Normal file
8
backend/go/bark/gobark.h
Normal file
|
@ -0,0 +1,8 @@
|
|||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
int load_model(char *model);
|
||||
int tts(char *text,int threads, char *dst );
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
20
backend/go/bark/main.go
Normal file
20
backend/go/bark/main.go
Normal file
|
@ -0,0 +1,20 @@
|
|||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
import (
|
||||
"flag"
|
||||
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &Bark{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
21
backend/go/image/stablediffusion-ggml/Makefile
Normal file
21
backend/go/image/stablediffusion-ggml/Makefile
Normal file
|
@ -0,0 +1,21 @@
|
|||
INCLUDE_PATH := $(abspath ./)
|
||||
LIBRARY_PATH := $(abspath ./)
|
||||
|
||||
AR?=ar
|
||||
|
||||
BUILD_TYPE?=
|
||||
# keep standard at C11 and C++11
|
||||
CXXFLAGS = -I. -I$(INCLUDE_PATH)/../../../../sources/stablediffusion-ggml.cpp/thirdparty -I$(INCLUDE_PATH)/../../../../sources/stablediffusion-ggml.cpp/ggml/include -I$(INCLUDE_PATH)/../../../../sources/stablediffusion-ggml.cpp -O3 -DNDEBUG -std=c++17 -fPIC
|
||||
|
||||
# warnings
|
||||
CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function
|
||||
|
||||
gosd.o:
|
||||
$(CXX) $(CXXFLAGS) gosd.cpp -o gosd.o -c
|
||||
|
||||
libsd.a: gosd.o
|
||||
cp $(INCLUDE_PATH)/../../../../sources/stablediffusion-ggml.cpp/build/libstable-diffusion.a ./libsd.a
|
||||
$(AR) rcs libsd.a gosd.o
|
||||
|
||||
clean:
|
||||
rm -f gosd.o libsd.a
|
228
backend/go/image/stablediffusion-ggml/gosd.cpp
Normal file
228
backend/go/image/stablediffusion-ggml/gosd.cpp
Normal file
|
@ -0,0 +1,228 @@
|
|||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include <time.h>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "gosd.h"
|
||||
|
||||
// #include "preprocessing.hpp"
|
||||
#include "flux.hpp"
|
||||
#include "stable-diffusion.h"
|
||||
|
||||
#define STB_IMAGE_IMPLEMENTATION
|
||||
#define STB_IMAGE_STATIC
|
||||
#include "stb_image.h"
|
||||
|
||||
#define STB_IMAGE_WRITE_IMPLEMENTATION
|
||||
#define STB_IMAGE_WRITE_STATIC
|
||||
#include "stb_image_write.h"
|
||||
|
||||
#define STB_IMAGE_RESIZE_IMPLEMENTATION
|
||||
#define STB_IMAGE_RESIZE_STATIC
|
||||
#include "stb_image_resize.h"
|
||||
|
||||
// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
|
||||
const char* sample_method_str[] = {
|
||||
"euler_a",
|
||||
"euler",
|
||||
"heun",
|
||||
"dpm2",
|
||||
"dpm++2s_a",
|
||||
"dpm++2m",
|
||||
"dpm++2mv2",
|
||||
"ipndm",
|
||||
"ipndm_v",
|
||||
"lcm",
|
||||
};
|
||||
|
||||
// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
|
||||
const char* schedule_str[] = {
|
||||
"default",
|
||||
"discrete",
|
||||
"karras",
|
||||
"exponential",
|
||||
"ays",
|
||||
"gits",
|
||||
};
|
||||
|
||||
sd_ctx_t* sd_c;
|
||||
|
||||
sample_method_t sample_method;
|
||||
|
||||
int load_model(char *model, char* options[], int threads, int diff) {
|
||||
fprintf (stderr, "Loading model!\n");
|
||||
|
||||
char *stableDiffusionModel = "";
|
||||
if (diff == 1 ) {
|
||||
stableDiffusionModel = model;
|
||||
model = "";
|
||||
}
|
||||
|
||||
// decode options. Options are in form optname:optvale, or if booleans only optname.
|
||||
char *clip_l_path = "";
|
||||
char *clip_g_path = "";
|
||||
char *t5xxl_path = "";
|
||||
char *vae_path = "";
|
||||
char *scheduler = "";
|
||||
char *sampler = "";
|
||||
|
||||
// If options is not NULL, parse options
|
||||
for (int i = 0; options[i] != NULL; i++) {
|
||||
char *optname = strtok(options[i], ":");
|
||||
char *optval = strtok(NULL, ":");
|
||||
if (optval == NULL) {
|
||||
optval = "true";
|
||||
}
|
||||
|
||||
if (!strcmp(optname, "clip_l_path")) {
|
||||
clip_l_path = optval;
|
||||
}
|
||||
if (!strcmp(optname, "clip_g_path")) {
|
||||
clip_g_path = optval;
|
||||
}
|
||||
if (!strcmp(optname, "t5xxl_path")) {
|
||||
t5xxl_path = optval;
|
||||
}
|
||||
if (!strcmp(optname, "vae_path")) {
|
||||
vae_path = optval;
|
||||
}
|
||||
if (!strcmp(optname, "scheduler")) {
|
||||
scheduler = optval;
|
||||
}
|
||||
if (!strcmp(optname, "sampler")) {
|
||||
sampler = optval;
|
||||
}
|
||||
}
|
||||
|
||||
int sample_method_found = -1;
|
||||
for (int m = 0; m < N_SAMPLE_METHODS; m++) {
|
||||
if (!strcmp(sampler, sample_method_str[m])) {
|
||||
sample_method_found = m;
|
||||
}
|
||||
}
|
||||
if (sample_method_found == -1) {
|
||||
fprintf(stderr, "Invalid sample method, default to EULER_A!\n");
|
||||
sample_method_found = EULER_A;
|
||||
}
|
||||
sample_method = (sample_method_t)sample_method_found;
|
||||
|
||||
int schedule_found = -1;
|
||||
for (int d = 0; d < N_SCHEDULES; d++) {
|
||||
if (!strcmp(scheduler, schedule_str[d])) {
|
||||
schedule_found = d;
|
||||
fprintf (stderr, "Found scheduler: %s\n", scheduler);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
if (schedule_found == -1) {
|
||||
fprintf (stderr, "Invalid scheduler! using DEFAULT\n");
|
||||
schedule_found = DEFAULT;
|
||||
}
|
||||
|
||||
schedule_t schedule = (schedule_t)schedule_found;
|
||||
|
||||
fprintf (stderr, "Creating context\n");
|
||||
sd_ctx_t* sd_ctx = new_sd_ctx(model,
|
||||
clip_l_path,
|
||||
clip_g_path,
|
||||
t5xxl_path,
|
||||
stableDiffusionModel,
|
||||
vae_path,
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
threads,
|
||||
SD_TYPE_COUNT,
|
||||
STD_DEFAULT_RNG,
|
||||
schedule,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false);
|
||||
|
||||
if (sd_ctx == NULL) {
|
||||
fprintf (stderr, "failed loading model (generic error)\n");
|
||||
return 1;
|
||||
}
|
||||
fprintf (stderr, "Created context: OK\n");
|
||||
|
||||
sd_c = sd_ctx;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int gen_image(char *text, char *negativeText, int width, int height, int steps, int seed , char *dst, float cfg_scale) {
|
||||
|
||||
sd_image_t* results;
|
||||
|
||||
std::vector<int> skip_layers = {7, 8, 9};
|
||||
|
||||
fprintf (stderr, "Generating image\n");
|
||||
|
||||
results = txt2img(sd_c,
|
||||
text,
|
||||
negativeText,
|
||||
-1, //clip_skip
|
||||
cfg_scale, // sfg_scale
|
||||
3.5f,
|
||||
width,
|
||||
height,
|
||||
sample_method,
|
||||
steps,
|
||||
seed,
|
||||
1,
|
||||
NULL,
|
||||
0.9f,
|
||||
20.f,
|
||||
false,
|
||||
"",
|
||||
skip_layers.data(),
|
||||
skip_layers.size(),
|
||||
0,
|
||||
0.01,
|
||||
0.2);
|
||||
|
||||
if (results == NULL) {
|
||||
fprintf (stderr, "NO results\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (results[0].data == NULL) {
|
||||
fprintf (stderr, "Results with no data\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf (stderr, "Writing PNG\n");
|
||||
|
||||
fprintf (stderr, "DST: %s\n", dst);
|
||||
fprintf (stderr, "Width: %d\n", results[0].width);
|
||||
fprintf (stderr, "Height: %d\n", results[0].height);
|
||||
fprintf (stderr, "Channel: %d\n", results[0].channel);
|
||||
fprintf (stderr, "Data: %p\n", results[0].data);
|
||||
|
||||
stbi_write_png(dst, results[0].width, results[0].height, results[0].channel,
|
||||
results[0].data, 0, NULL);
|
||||
fprintf (stderr, "Saved resulting image to '%s'\n", dst);
|
||||
|
||||
// TODO: free results. Why does it crash?
|
||||
|
||||
free(results[0].data);
|
||||
results[0].data = NULL;
|
||||
free(results);
|
||||
fprintf (stderr, "gen_image is done", dst);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int unload() {
|
||||
free_sd_ctx(sd_c);
|
||||
}
|
||||
|
96
backend/go/image/stablediffusion-ggml/gosd.go
Normal file
96
backend/go/image/stablediffusion-ggml/gosd.go
Normal file
|
@ -0,0 +1,96 @@
|
|||
package main
|
||||
|
||||
// #cgo CXXFLAGS: -I${SRCDIR}/../../../../sources/stablediffusion-ggml.cpp/thirdparty -I${SRCDIR}/../../../../sources/stablediffusion-ggml.cpp -I${SRCDIR}/../../../../sources/stablediffusion-ggml.cpp/ggml/include
|
||||
// #cgo LDFLAGS: -L${SRCDIR}/ -L${SRCDIR}/../../../../sources/stablediffusion-ggml.cpp/build/ggml/src/ggml-cpu -L${SRCDIR}/../../../../sources/stablediffusion-ggml.cpp/build/ggml/src -lsd -lstdc++ -lm -lggml -lggml-base -lggml-cpu -lgomp
|
||||
// #include <gosd.h>
|
||||
// #include <stdlib.h>
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
)
|
||||
|
||||
type SDGGML struct {
|
||||
base.SingleThread
|
||||
threads int
|
||||
sampleMethod string
|
||||
cfgScale float32
|
||||
}
|
||||
|
||||
func (sd *SDGGML) Load(opts *pb.ModelOptions) error {
|
||||
|
||||
sd.threads = int(opts.Threads)
|
||||
|
||||
modelFile := C.CString(opts.ModelFile)
|
||||
defer C.free(unsafe.Pointer(modelFile))
|
||||
|
||||
var options **C.char
|
||||
// prepare the options array to pass to C
|
||||
|
||||
size := C.size_t(unsafe.Sizeof((*C.char)(nil)))
|
||||
length := C.size_t(len(opts.Options))
|
||||
options = (**C.char)(C.malloc(length * size))
|
||||
view := (*[1 << 30]*C.char)(unsafe.Pointer(options))[0:len(opts.Options):len(opts.Options)]
|
||||
|
||||
var diffusionModel int
|
||||
|
||||
var oo []string
|
||||
for _, op := range opts.Options {
|
||||
if op == "diffusion_model" {
|
||||
diffusionModel = 1
|
||||
continue
|
||||
}
|
||||
|
||||
// If it's an option path, we resolve absolute path from the model path
|
||||
if strings.Contains(op, ":") && strings.Contains(op, "path") {
|
||||
data := strings.Split(op, ":")
|
||||
data[1] = filepath.Join(opts.ModelPath, data[1])
|
||||
if err := utils.VerifyPath(data[1], opts.ModelPath); err == nil {
|
||||
oo = append(oo, strings.Join(data, ":"))
|
||||
}
|
||||
} else {
|
||||
oo = append(oo, op)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Options: %+v\n", oo)
|
||||
|
||||
for i, x := range oo {
|
||||
view[i] = C.CString(x)
|
||||
}
|
||||
|
||||
sd.cfgScale = opts.CFGScale
|
||||
|
||||
ret := C.load_model(modelFile, options, C.int(opts.Threads), C.int(diffusionModel))
|
||||
if ret != 0 {
|
||||
return fmt.Errorf("could not load model")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sd *SDGGML) GenerateImage(opts *pb.GenerateImageRequest) error {
|
||||
t := C.CString(opts.PositivePrompt)
|
||||
defer C.free(unsafe.Pointer(t))
|
||||
|
||||
dst := C.CString(opts.Dst)
|
||||
defer C.free(unsafe.Pointer(dst))
|
||||
|
||||
negative := C.CString(opts.NegativePrompt)
|
||||
defer C.free(unsafe.Pointer(negative))
|
||||
|
||||
ret := C.gen_image(t, negative, C.int(opts.Width), C.int(opts.Height), C.int(opts.Step), C.int(opts.Seed), dst, C.float(sd.cfgScale))
|
||||
if ret != 0 {
|
||||
return fmt.Errorf("inference failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
8
backend/go/image/stablediffusion-ggml/gosd.h
Normal file
8
backend/go/image/stablediffusion-ggml/gosd.h
Normal file
|
@ -0,0 +1,8 @@
|
|||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
int load_model(char *model, char* options[], int threads, int diffusionModel);
|
||||
int gen_image(char *text, char *negativeText, int width, int height, int steps, int seed, char *dst, float cfg_scale);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
20
backend/go/image/stablediffusion-ggml/main.go
Normal file
20
backend/go/image/stablediffusion-ggml/main.go
Normal file
|
@ -0,0 +1,20 @@
|
|||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
import (
|
||||
"flag"
|
||||
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &SDGGML{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
|
@ -2,4 +2,4 @@
|
|||
intel-extension-for-pytorch
|
||||
torch
|
||||
optimum[openvino]
|
||||
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
|
||||
setuptools
|
|
@ -1,6 +1,6 @@
|
|||
accelerate
|
||||
auto-gptq==0.7.1
|
||||
grpcio==1.68.0
|
||||
grpcio==1.68.1
|
||||
protobuf
|
||||
certifi
|
||||
transformers
|
|
@ -3,6 +3,6 @@ intel-extension-for-pytorch
|
|||
torch
|
||||
torchaudio
|
||||
optimum[openvino]
|
||||
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
|
||||
setuptools
|
||||
transformers
|
||||
accelerate
|
|
@ -1,4 +1,4 @@
|
|||
bark==0.1.5
|
||||
grpcio==1.68.0
|
||||
grpcio==1.68.1
|
||||
protobuf
|
||||
certifi
|
|
@ -17,6 +17,9 @@
|
|||
# LIMIT_TARGETS="cublas12"
|
||||
# source $(dirname $0)/../common/libbackend.sh
|
||||
#
|
||||
|
||||
PYTHON_VERSION="3.10"
|
||||
|
||||
function init() {
|
||||
# Name of the backend (directory name)
|
||||
BACKEND_NAME=${PWD##*/}
|
||||
|
@ -88,7 +91,7 @@ function getBuildProfile() {
|
|||
# always result in an activated virtual environment
|
||||
function ensureVenv() {
|
||||
if [ ! -d "${EDIR}/venv" ]; then
|
||||
uv venv ${EDIR}/venv
|
||||
uv venv --python ${PYTHON_VERSION} ${EDIR}/venv
|
||||
echo "virtualenv created"
|
||||
fi
|
||||
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
grpcio==1.68.0
|
||||
grpcio==1.68.1
|
||||
protobuf
|
||||
grpcio-tools
|
|
@ -3,7 +3,7 @@ intel-extension-for-pytorch
|
|||
torch
|
||||
torchaudio
|
||||
optimum[openvino]
|
||||
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
|
||||
setuptools
|
||||
transformers
|
||||
accelerate
|
||||
coqui-tts
|
|
@ -1,4 +1,4 @@
|
|||
grpcio==1.68.0
|
||||
grpcio==1.68.1
|
||||
protobuf
|
||||
certifi
|
||||
packaging==24.1
|
|
@ -3,7 +3,7 @@ intel-extension-for-pytorch
|
|||
torch
|
||||
torchvision
|
||||
optimum[openvino]
|
||||
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
|
||||
setuptools
|
||||
diffusers
|
||||
opencv-python
|
||||
transformers
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
setuptools
|
||||
grpcio==1.68.0
|
||||
grpcio==1.68.1
|
||||
pillow
|
||||
protobuf
|
||||
certifi
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
grpcio==1.68.0
|
||||
grpcio==1.68.1
|
||||
protobuf
|
||||
certifi
|
||||
wheel
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
grpcio==1.68.0
|
||||
grpcio==1.68.1
|
||||
protobuf
|
||||
certifi
|
|
@ -2,7 +2,7 @@
|
|||
intel-extension-for-pytorch
|
||||
torch
|
||||
optimum[openvino]
|
||||
grpcio==1.68.0
|
||||
grpcio==1.68.1
|
||||
protobuf
|
||||
librosa==0.9.1
|
||||
faster-whisper==0.9.0
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
grpcio==1.68.0
|
||||
grpcio==1.68.1
|
||||
protobuf
|
||||
librosa
|
||||
faster-whisper
|
||||
|
@ -18,3 +18,4 @@ jieba==0.42.1
|
|||
gradio==3.48.0
|
||||
langid==1.1.6
|
||||
llvmlite==0.43.0
|
||||
setuptools
|
|
@ -3,6 +3,5 @@ intel-extension-for-pytorch
|
|||
torch
|
||||
torchaudio
|
||||
optimum[openvino]
|
||||
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
|
||||
transformers
|
||||
accelerate
|
|
@ -1,3 +1,4 @@
|
|||
grpcio==1.68.0
|
||||
grpcio==1.68.1
|
||||
certifi
|
||||
llvmlite==0.43.0
|
||||
setuptools
|
|
@ -5,4 +5,4 @@ accelerate
|
|||
torch
|
||||
rerankers[transformers]
|
||||
optimum[openvino]
|
||||
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
|
||||
setuptools
|
|
@ -1,3 +1,3 @@
|
|||
grpcio==1.68.0
|
||||
grpcio==1.68.1
|
||||
protobuf
|
||||
certifi
|
|
@ -2,7 +2,7 @@
|
|||
intel-extension-for-pytorch
|
||||
torch
|
||||
optimum[openvino]
|
||||
setuptools==69.5.1 # https://github.com/mudler/LocalAI/issues/2406
|
||||
setuptools
|
||||
accelerate
|
||||
sentence-transformers==3.3.1
|
||||
transformers
|
|
@ -1,4 +1,4 @@
|
|||
grpcio==1.68.0
|
||||
grpcio==1.68.1
|
||||
protobuf
|
||||
certifi
|
||||
datasets
|
||||
|
|
|
@ -4,4 +4,4 @@ transformers
|
|||
accelerate
|
||||
torch
|
||||
optimum[openvino]
|
||||
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
|
||||
setuptools
|
|
@ -1,4 +1,4 @@
|
|||
grpcio==1.68.0
|
||||
grpcio==1.68.1
|
||||
protobuf
|
||||
scipy==1.14.0
|
||||
certifi
|
|
@ -1,4 +1,4 @@
|
|||
grpcio==1.68.0
|
||||
grpcio==1.68.1
|
||||
protobuf
|
||||
certifi
|
||||
setuptools==69.5.1 # https://github.com/mudler/LocalAI/issues/2406
|
||||
setuptools
|
|
@ -3,5 +3,4 @@ intel-extension-for-pytorch
|
|||
accelerate
|
||||
torch
|
||||
torchaudio
|
||||
optimum[openvino]
|
||||
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
|
||||
optimum[openvino]
|
|
@ -1,3 +1,4 @@
|
|||
grpcio==1.68.0
|
||||
grpcio==1.68.1
|
||||
protobuf
|
||||
certifi
|
||||
certifi
|
||||
setuptools
|
|
@ -22,7 +22,7 @@ if [ "x${BUILD_TYPE}" == "x" ] && [ "x${FROM_SOURCE}" == "xtrue" ]; then
|
|||
git clone https://github.com/vllm-project/vllm
|
||||
fi
|
||||
pushd vllm
|
||||
uv pip install wheel packaging ninja "setuptools>=49.4.0" numpy typing-extensions pillow setuptools-scm grpcio==1.68.0 protobuf bitsandbytes
|
||||
uv pip install wheel packaging ninja "setuptools>=49.4.0" numpy typing-extensions pillow setuptools-scm grpcio==1.68.1 protobuf bitsandbytes
|
||||
uv pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
VLLM_TARGET_DEVICE=cpu python setup.py install
|
||||
popd
|
||||
|
|
|
@ -4,5 +4,5 @@ accelerate
|
|||
torch
|
||||
transformers
|
||||
optimum[openvino]
|
||||
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
|
||||
setuptools
|
||||
bitsandbytes
|
|
@ -1,4 +1,4 @@
|
|||
grpcio==1.68.0
|
||||
grpcio==1.68.1
|
||||
protobuf
|
||||
certifi
|
||||
setuptools
|
|
@ -1,38 +0,0 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// The purpose of this structure is to hold pointers to all initialized services, to make plumbing easy
|
||||
// Perhaps a proper DI system is worth it in the future, but for now keep things simple.
|
||||
type Application struct {
|
||||
|
||||
// Application-Level Config
|
||||
ApplicationConfig *config.ApplicationConfig
|
||||
// ApplicationState *ApplicationState
|
||||
|
||||
// Core Low-Level Services
|
||||
BackendConfigLoader *config.BackendConfigLoader
|
||||
ModelLoader *model.ModelLoader
|
||||
|
||||
// Backend Services
|
||||
// EmbeddingsBackendService *backend.EmbeddingsBackendService
|
||||
// ImageGenerationBackendService *backend.ImageGenerationBackendService
|
||||
// LLMBackendService *backend.LLMBackendService
|
||||
// TranscriptionBackendService *backend.TranscriptionBackendService
|
||||
// TextToSpeechBackendService *backend.TextToSpeechBackendService
|
||||
|
||||
// LocalAI System Services
|
||||
BackendMonitorService *services.BackendMonitorService
|
||||
GalleryService *services.GalleryService
|
||||
LocalAIMetricsService *services.LocalAIMetricsService
|
||||
// OpenAIService *services.OpenAIService
|
||||
}
|
||||
|
||||
// TODO [NEXT PR?]: Break up ApplicationConfig.
|
||||
// Migrate over stuff that is not set via config at all - especially runtime stuff
|
||||
type ApplicationState struct {
|
||||
}
|
39
core/application/application.go
Normal file
39
core/application/application.go
Normal file
|
@ -0,0 +1,39 @@
|
|||
package application
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/templates"
|
||||
)
|
||||
|
||||
type Application struct {
|
||||
backendLoader *config.BackendConfigLoader
|
||||
modelLoader *model.ModelLoader
|
||||
applicationConfig *config.ApplicationConfig
|
||||
templatesEvaluator *templates.Evaluator
|
||||
}
|
||||
|
||||
func newApplication(appConfig *config.ApplicationConfig) *Application {
|
||||
return &Application{
|
||||
backendLoader: config.NewBackendConfigLoader(appConfig.ModelPath),
|
||||
modelLoader: model.NewModelLoader(appConfig.ModelPath),
|
||||
applicationConfig: appConfig,
|
||||
templatesEvaluator: templates.NewEvaluator(appConfig.ModelPath),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Application) BackendLoader() *config.BackendConfigLoader {
|
||||
return a.backendLoader
|
||||
}
|
||||
|
||||
func (a *Application) ModelLoader() *model.ModelLoader {
|
||||
return a.modelLoader
|
||||
}
|
||||
|
||||
func (a *Application) ApplicationConfig() *config.ApplicationConfig {
|
||||
return a.applicationConfig
|
||||
}
|
||||
|
||||
func (a *Application) TemplatesEvaluator() *templates.Evaluator {
|
||||
return a.templatesEvaluator
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package startup
|
||||
package application
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
@ -8,8 +8,8 @@ import (
|
|||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"dario.cat/mergo"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
|
@ -1,15 +1,15 @@
|
|||
package startup
|
||||
package application
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/mudler/LocalAI/core"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/assets"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/library"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
pkgStartup "github.com/mudler/LocalAI/pkg/startup"
|
||||
|
@ -17,8 +17,9 @@ import (
|
|||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.ModelLoader, *config.ApplicationConfig, error) {
|
||||
func New(opts ...config.AppOption) (*Application, error) {
|
||||
options := config.NewApplicationConfig(opts...)
|
||||
application := newApplication(options)
|
||||
|
||||
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.ModelPath)
|
||||
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
|
||||
|
@ -36,28 +37,28 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
|
|||
|
||||
// Make sure directories exists
|
||||
if options.ModelPath == "" {
|
||||
return nil, nil, nil, fmt.Errorf("options.ModelPath cannot be empty")
|
||||
return nil, fmt.Errorf("options.ModelPath cannot be empty")
|
||||
}
|
||||
err = os.MkdirAll(options.ModelPath, 0750)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("unable to create ModelPath: %q", err)
|
||||
return nil, fmt.Errorf("unable to create ModelPath: %q", err)
|
||||
}
|
||||
if options.ImageDir != "" {
|
||||
err := os.MkdirAll(options.ImageDir, 0750)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("unable to create ImageDir: %q", err)
|
||||
return nil, fmt.Errorf("unable to create ImageDir: %q", err)
|
||||
}
|
||||
}
|
||||
if options.AudioDir != "" {
|
||||
err := os.MkdirAll(options.AudioDir, 0750)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("unable to create AudioDir: %q", err)
|
||||
return nil, fmt.Errorf("unable to create AudioDir: %q", err)
|
||||
}
|
||||
}
|
||||
if options.UploadDir != "" {
|
||||
err := os.MkdirAll(options.UploadDir, 0750)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("unable to create UploadDir: %q", err)
|
||||
return nil, fmt.Errorf("unable to create UploadDir: %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -65,39 +66,36 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
|
|||
log.Error().Err(err).Msg("error installing models")
|
||||
}
|
||||
|
||||
cl := config.NewBackendConfigLoader(options.ModelPath)
|
||||
ml := model.NewModelLoader(options.ModelPath)
|
||||
|
||||
configLoaderOpts := options.ToConfigLoaderOptions()
|
||||
|
||||
if err := cl.LoadBackendConfigsFromPath(options.ModelPath, configLoaderOpts...); err != nil {
|
||||
if err := application.BackendLoader().LoadBackendConfigsFromPath(options.ModelPath, configLoaderOpts...); err != nil {
|
||||
log.Error().Err(err).Msg("error loading config files")
|
||||
}
|
||||
|
||||
if options.ConfigFile != "" {
|
||||
if err := cl.LoadMultipleBackendConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil {
|
||||
if err := application.BackendLoader().LoadMultipleBackendConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil {
|
||||
log.Error().Err(err).Msg("error loading config file")
|
||||
}
|
||||
}
|
||||
|
||||
if err := cl.Preload(options.ModelPath); err != nil {
|
||||
if err := application.BackendLoader().Preload(options.ModelPath); err != nil {
|
||||
log.Error().Err(err).Msg("error downloading models")
|
||||
}
|
||||
|
||||
if options.PreloadJSONModels != "" {
|
||||
if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, options.EnforcePredownloadScans, options.Galleries); err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if options.PreloadModelsFromPath != "" {
|
||||
if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, options.EnforcePredownloadScans, options.Galleries); err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if options.Debug {
|
||||
for _, v := range cl.GetAllBackendConfigs() {
|
||||
for _, v := range application.BackendLoader().GetAllBackendConfigs() {
|
||||
log.Debug().Msgf("Model: %s (config: %+v)", v.Name, v)
|
||||
}
|
||||
}
|
||||
|
@ -123,7 +121,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
|
|||
go func() {
|
||||
<-options.Context.Done()
|
||||
log.Debug().Msgf("Context canceled, shutting down")
|
||||
err := ml.StopAllGRPC()
|
||||
err := application.ModelLoader().StopAllGRPC()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("error while stopping all grpc backends")
|
||||
}
|
||||
|
@ -131,12 +129,12 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
|
|||
|
||||
if options.WatchDog {
|
||||
wd := model.NewWatchDog(
|
||||
ml,
|
||||
application.ModelLoader(),
|
||||
options.WatchDogBusyTimeout,
|
||||
options.WatchDogIdleTimeout,
|
||||
options.WatchDogBusy,
|
||||
options.WatchDogIdle)
|
||||
ml.SetWatchDog(wd)
|
||||
application.ModelLoader().SetWatchDog(wd)
|
||||
go wd.Run()
|
||||
go func() {
|
||||
<-options.Context.Done()
|
||||
|
@ -147,7 +145,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
|
|||
|
||||
if options.LoadToMemory != nil {
|
||||
for _, m := range options.LoadToMemory {
|
||||
cfg, err := cl.LoadBackendConfigFileByName(m, options.ModelPath,
|
||||
cfg, err := application.BackendLoader().LoadBackendConfigFileByName(m, options.ModelPath,
|
||||
config.LoadOptionDebug(options.Debug),
|
||||
config.LoadOptionThreads(options.Threads),
|
||||
config.LoadOptionContextSize(options.ContextSize),
|
||||
|
@ -155,7 +153,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
|
|||
config.ModelPath(options.ModelPath),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Auto loading model %s into memory from file: %s", m, cfg.Model)
|
||||
|
@ -163,9 +161,9 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
|
|||
o := backend.ModelOptions(*cfg, options)
|
||||
|
||||
var backendErr error
|
||||
_, backendErr = ml.Load(o...)
|
||||
_, backendErr = application.ModelLoader().Load(o...)
|
||||
if backendErr != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -174,7 +172,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
|
|||
startWatcher(options)
|
||||
|
||||
log.Info().Msg("core/startup process completed!")
|
||||
return cl, ml, options, nil
|
||||
return application, nil
|
||||
}
|
||||
|
||||
func startWatcher(options *config.ApplicationConfig) {
|
||||
|
@ -201,32 +199,3 @@ func startWatcher(options *config.ApplicationConfig) {
|
|||
log.Error().Err(err).Msg("failed creating watcher")
|
||||
}
|
||||
}
|
||||
|
||||
// In Lieu of a proper DI framework, this function wires up the Application manually.
|
||||
// This is in core/startup rather than core/state.go to keep package references clean!
|
||||
func createApplication(appConfig *config.ApplicationConfig) *core.Application {
|
||||
app := &core.Application{
|
||||
ApplicationConfig: appConfig,
|
||||
BackendConfigLoader: config.NewBackendConfigLoader(appConfig.ModelPath),
|
||||
ModelLoader: model.NewModelLoader(appConfig.ModelPath),
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
// app.EmbeddingsBackendService = backend.NewEmbeddingsBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
|
||||
// app.ImageGenerationBackendService = backend.NewImageGenerationBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
|
||||
// app.LLMBackendService = backend.NewLLMBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
|
||||
// app.TranscriptionBackendService = backend.NewTranscriptionBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
|
||||
// app.TextToSpeechBackendService = backend.NewTextToSpeechBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
|
||||
|
||||
app.BackendMonitorService = services.NewBackendMonitorService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
|
||||
app.GalleryService = services.NewGalleryService(app.ApplicationConfig)
|
||||
// app.OpenAIService = services.NewOpenAIService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig, app.LLMBackendService)
|
||||
|
||||
app.LocalAIMetricsService, err = services.NewLocalAIMetricsService()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("encountered an error initializing metrics service, startup will continue but metrics will not be tracked.")
|
||||
}
|
||||
|
||||
return app
|
||||
}
|
|
@ -122,7 +122,7 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
|
|||
CUDA: c.CUDA || c.Diffusers.CUDA,
|
||||
SchedulerType: c.Diffusers.SchedulerType,
|
||||
PipelineType: c.Diffusers.PipelineType,
|
||||
CFGScale: c.Diffusers.CFGScale,
|
||||
CFGScale: c.CFGScale,
|
||||
LoraAdapter: c.LoraAdapter,
|
||||
LoraScale: c.LoraScale,
|
||||
LoraAdapters: c.LoraAdapters,
|
||||
|
@ -132,6 +132,7 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
|
|||
IMG2IMG: c.Diffusers.IMG2IMG,
|
||||
CLIPModel: c.Diffusers.ClipModel,
|
||||
CLIPSubfolder: c.Diffusers.ClipSubFolder,
|
||||
Options: c.Options,
|
||||
CLIPSkip: int32(c.Diffusers.ClipSkip),
|
||||
ControlNet: c.Diffusers.ControlNet,
|
||||
ContextSize: int32(ctxSize),
|
||||
|
@ -150,6 +151,8 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
|
|||
TensorParallelSize: int32(c.TensorParallelSize),
|
||||
MMProj: c.MMProj,
|
||||
FlashAttention: c.FlashAttention,
|
||||
CacheTypeKey: c.CacheTypeK,
|
||||
CacheTypeValue: c.CacheTypeV,
|
||||
NoKVOffload: c.NoKVOffloading,
|
||||
YarnExtFactor: c.YarnExtFactor,
|
||||
YarnAttnFactor: c.YarnAttnFactor,
|
||||
|
|
|
@ -6,12 +6,12 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
cli_api "github.com/mudler/LocalAI/core/cli/api"
|
||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http"
|
||||
"github.com/mudler/LocalAI/core/p2p"
|
||||
"github.com/mudler/LocalAI/core/startup"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
@ -186,16 +186,16 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
|||
}
|
||||
|
||||
if r.PreloadBackendOnly {
|
||||
_, _, _, err := startup.Startup(opts...)
|
||||
_, err := application.New(opts...)
|
||||
return err
|
||||
}
|
||||
|
||||
cl, ml, options, err := startup.Startup(opts...)
|
||||
app, err := application.New(opts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed basic startup tasks with error %s", err.Error())
|
||||
}
|
||||
|
||||
appHTTP, err := http.App(cl, ml, options)
|
||||
appHTTP, err := http.API(app)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("error during HTTP App construction")
|
||||
return err
|
||||
|
|
|
@ -72,6 +72,8 @@ type BackendConfig struct {
|
|||
|
||||
Description string `yaml:"description"`
|
||||
Usage string `yaml:"usage"`
|
||||
|
||||
Options []string `yaml:"options"`
|
||||
}
|
||||
|
||||
type File struct {
|
||||
|
@ -97,16 +99,15 @@ type GRPC struct {
|
|||
}
|
||||
|
||||
type Diffusers struct {
|
||||
CUDA bool `yaml:"cuda"`
|
||||
PipelineType string `yaml:"pipeline_type"`
|
||||
SchedulerType string `yaml:"scheduler_type"`
|
||||
EnableParameters string `yaml:"enable_parameters"` // A list of comma separated parameters to specify
|
||||
CFGScale float32 `yaml:"cfg_scale"` // Classifier-Free Guidance Scale
|
||||
IMG2IMG bool `yaml:"img2img"` // Image to Image Diffuser
|
||||
ClipSkip int `yaml:"clip_skip"` // Skip every N frames
|
||||
ClipModel string `yaml:"clip_model"` // Clip model to use
|
||||
ClipSubFolder string `yaml:"clip_subfolder"` // Subfolder to use for clip model
|
||||
ControlNet string `yaml:"control_net"`
|
||||
CUDA bool `yaml:"cuda"`
|
||||
PipelineType string `yaml:"pipeline_type"`
|
||||
SchedulerType string `yaml:"scheduler_type"`
|
||||
EnableParameters string `yaml:"enable_parameters"` // A list of comma separated parameters to specify
|
||||
IMG2IMG bool `yaml:"img2img"` // Image to Image Diffuser
|
||||
ClipSkip int `yaml:"clip_skip"` // Skip every N frames
|
||||
ClipModel string `yaml:"clip_model"` // Clip model to use
|
||||
ClipSubFolder string `yaml:"clip_subfolder"` // Subfolder to use for clip model
|
||||
ControlNet string `yaml:"control_net"`
|
||||
}
|
||||
|
||||
// LLMConfig is a struct that holds the configuration that are
|
||||
|
@ -154,8 +155,10 @@ type LLMConfig struct {
|
|||
TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM
|
||||
MMProj string `yaml:"mmproj"`
|
||||
|
||||
FlashAttention bool `yaml:"flash_attention"`
|
||||
NoKVOffloading bool `yaml:"no_kv_offloading"`
|
||||
FlashAttention bool `yaml:"flash_attention"`
|
||||
NoKVOffloading bool `yaml:"no_kv_offloading"`
|
||||
CacheTypeK string `yaml:"cache_type_k"`
|
||||
CacheTypeV string `yaml:"cache_type_v"`
|
||||
|
||||
RopeScaling string `yaml:"rope_scaling"`
|
||||
ModelType string `yaml:"type"`
|
||||
|
@ -164,6 +167,8 @@ type LLMConfig struct {
|
|||
YarnAttnFactor float32 `yaml:"yarn_attn_factor"`
|
||||
YarnBetaFast float32 `yaml:"yarn_beta_fast"`
|
||||
YarnBetaSlow float32 `yaml:"yarn_beta_slow"`
|
||||
|
||||
CFGScale float32 `yaml:"cfg_scale"` // Classifier-Free Guidance Scale
|
||||
}
|
||||
|
||||
// AutoGPTQ is a struct that holds the configuration specific to the AutoGPTQ backend
|
||||
|
@ -201,6 +206,8 @@ type TemplateConfig struct {
|
|||
JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character"`
|
||||
|
||||
Multimodal string `yaml:"multimodal"`
|
||||
|
||||
JinjaTemplate bool `yaml:"jinja_template"`
|
||||
}
|
||||
|
||||
func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {
|
||||
|
|
|
@ -26,14 +26,14 @@ const (
|
|||
type settingsConfig struct {
|
||||
StopWords []string
|
||||
TemplateConfig TemplateConfig
|
||||
RepeatPenalty float64
|
||||
RepeatPenalty float64
|
||||
}
|
||||
|
||||
// default settings to adopt with a given model family
|
||||
var defaultsSettings map[familyType]settingsConfig = map[familyType]settingsConfig{
|
||||
Gemma: {
|
||||
RepeatPenalty: 1.0,
|
||||
StopWords: []string{"<|im_end|>", "<end_of_turn>", "<start_of_turn>"},
|
||||
StopWords: []string{"<|im_end|>", "<end_of_turn>", "<start_of_turn>"},
|
||||
TemplateConfig: TemplateConfig{
|
||||
Chat: "{{.Input }}\n<start_of_turn>model\n",
|
||||
ChatMessage: "<start_of_turn>{{if eq .RoleName \"assistant\" }}model{{else}}{{ .RoleName }}{{end}}\n{{ if .Content -}}\n{{.Content -}}\n{{ end -}}<end_of_turn>",
|
||||
|
@ -200,6 +200,18 @@ func guessDefaultsFromFile(cfg *BackendConfig, modelPath string) {
|
|||
} else {
|
||||
log.Debug().Any("family", family).Msgf("guessDefaultsFromFile: no template found for family")
|
||||
}
|
||||
|
||||
if cfg.HasTemplate() {
|
||||
return
|
||||
}
|
||||
|
||||
// identify from well known templates first, otherwise use the raw jinja template
|
||||
chatTemplate, found := f.Header.MetadataKV.Get("tokenizer.chat_template")
|
||||
if found {
|
||||
// try to use the jinja template
|
||||
cfg.TemplateConfig.JinjaTemplate = true
|
||||
cfg.TemplateConfig.ChatMessage = chatTemplate.ValueString()
|
||||
}
|
||||
}
|
||||
|
||||
func identifyFamily(f *gguf.GGUFFile) familyType {
|
||||
|
|
|
@ -14,10 +14,9 @@ import (
|
|||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/http/routes"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
|
||||
"github.com/gofiber/contrib/fiberzerolog"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
@ -49,18 +48,18 @@ var embedDirStatic embed.FS
|
|||
// @in header
|
||||
// @name Authorization
|
||||
|
||||
func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (*fiber.App, error) {
|
||||
func API(application *application.Application) (*fiber.App, error) {
|
||||
|
||||
fiberCfg := fiber.Config{
|
||||
Views: renderEngine(),
|
||||
BodyLimit: appConfig.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
|
||||
BodyLimit: application.ApplicationConfig().UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
|
||||
// We disable the Fiber startup message as it does not conform to structured logging.
|
||||
// We register a startup log line with connection information in the OnListen hook to keep things user friendly though
|
||||
DisableStartupMessage: true,
|
||||
// Override default error handler
|
||||
}
|
||||
|
||||
if !appConfig.OpaqueErrors {
|
||||
if !application.ApplicationConfig().OpaqueErrors {
|
||||
// Normally, return errors as JSON responses
|
||||
fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, err error) error {
|
||||
// Status code defaults to 500
|
||||
|
@ -86,9 +85,9 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
|
|||
}
|
||||
}
|
||||
|
||||
app := fiber.New(fiberCfg)
|
||||
router := fiber.New(fiberCfg)
|
||||
|
||||
app.Hooks().OnListen(func(listenData fiber.ListenData) error {
|
||||
router.Hooks().OnListen(func(listenData fiber.ListenData) error {
|
||||
scheme := "http"
|
||||
if listenData.TLS {
|
||||
scheme = "https"
|
||||
|
@ -99,82 +98,82 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
|
|||
|
||||
// Have Fiber use zerolog like the rest of the application rather than it's built-in logger
|
||||
logger := log.Logger
|
||||
app.Use(fiberzerolog.New(fiberzerolog.Config{
|
||||
router.Use(fiberzerolog.New(fiberzerolog.Config{
|
||||
Logger: &logger,
|
||||
}))
|
||||
|
||||
// Default middleware config
|
||||
|
||||
if !appConfig.Debug {
|
||||
app.Use(recover.New())
|
||||
if !application.ApplicationConfig().Debug {
|
||||
router.Use(recover.New())
|
||||
}
|
||||
|
||||
if !appConfig.DisableMetrics {
|
||||
if !application.ApplicationConfig().DisableMetrics {
|
||||
metricsService, err := services.NewLocalAIMetricsService()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if metricsService != nil {
|
||||
app.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
|
||||
app.Hooks().OnShutdown(func() error {
|
||||
router.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
|
||||
router.Hooks().OnShutdown(func() error {
|
||||
return metricsService.Shutdown()
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
// Health Checks should always be exempt from auth, so register these first
|
||||
routes.HealthRoutes(app)
|
||||
routes.HealthRoutes(router)
|
||||
|
||||
kaConfig, err := middleware.GetKeyAuthConfig(appConfig)
|
||||
kaConfig, err := middleware.GetKeyAuthConfig(application.ApplicationConfig())
|
||||
if err != nil || kaConfig == nil {
|
||||
return nil, fmt.Errorf("failed to create key auth config: %w", err)
|
||||
}
|
||||
|
||||
// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration
|
||||
app.Use(v2keyauth.New(*kaConfig))
|
||||
router.Use(v2keyauth.New(*kaConfig))
|
||||
|
||||
if appConfig.CORS {
|
||||
if application.ApplicationConfig().CORS {
|
||||
var c func(ctx *fiber.Ctx) error
|
||||
if appConfig.CORSAllowOrigins == "" {
|
||||
if application.ApplicationConfig().CORSAllowOrigins == "" {
|
||||
c = cors.New()
|
||||
} else {
|
||||
c = cors.New(cors.Config{AllowOrigins: appConfig.CORSAllowOrigins})
|
||||
c = cors.New(cors.Config{AllowOrigins: application.ApplicationConfig().CORSAllowOrigins})
|
||||
}
|
||||
|
||||
app.Use(c)
|
||||
router.Use(c)
|
||||
}
|
||||
|
||||
if appConfig.CSRF {
|
||||
if application.ApplicationConfig().CSRF {
|
||||
log.Debug().Msg("Enabling CSRF middleware. Tokens are now required for state-modifying requests")
|
||||
app.Use(csrf.New())
|
||||
router.Use(csrf.New())
|
||||
}
|
||||
|
||||
// Load config jsons
|
||||
utils.LoadConfig(appConfig.UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles)
|
||||
utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants)
|
||||
utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles)
|
||||
utils.LoadConfig(application.ApplicationConfig().UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles)
|
||||
utils.LoadConfig(application.ApplicationConfig().ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants)
|
||||
utils.LoadConfig(application.ApplicationConfig().ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles)
|
||||
|
||||
galleryService := services.NewGalleryService(appConfig)
|
||||
galleryService.Start(appConfig.Context, cl)
|
||||
galleryService := services.NewGalleryService(application.ApplicationConfig())
|
||||
galleryService.Start(application.ApplicationConfig().Context, application.BackendLoader())
|
||||
|
||||
routes.RegisterElevenLabsRoutes(app, cl, ml, appConfig)
|
||||
routes.RegisterLocalAIRoutes(app, cl, ml, appConfig, galleryService)
|
||||
routes.RegisterOpenAIRoutes(app, cl, ml, appConfig)
|
||||
if !appConfig.DisableWebUI {
|
||||
routes.RegisterUIRoutes(app, cl, ml, appConfig, galleryService)
|
||||
routes.RegisterElevenLabsRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
routes.RegisterLocalAIRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService)
|
||||
routes.RegisterOpenAIRoutes(router, application)
|
||||
if !application.ApplicationConfig().DisableWebUI {
|
||||
routes.RegisterUIRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService)
|
||||
}
|
||||
routes.RegisterJINARoutes(app, cl, ml, appConfig)
|
||||
routes.RegisterJINARoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||
|
||||
httpFS := http.FS(embedDirStatic)
|
||||
|
||||
app.Use(favicon.New(favicon.Config{
|
||||
router.Use(favicon.New(favicon.Config{
|
||||
URL: "/favicon.ico",
|
||||
FileSystem: httpFS,
|
||||
File: "static/favicon.ico",
|
||||
}))
|
||||
|
||||
app.Use("/static", filesystem.New(filesystem.Config{
|
||||
router.Use("/static", filesystem.New(filesystem.Config{
|
||||
Root: httpFS,
|
||||
PathPrefix: "static",
|
||||
Browse: true,
|
||||
|
@ -182,7 +181,7 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
|
|||
|
||||
// Define a custom 404 handler
|
||||
// Note: keep this at the bottom!
|
||||
app.Use(notFoundHandler)
|
||||
router.Use(notFoundHandler)
|
||||
|
||||
return app, nil
|
||||
return router, nil
|
||||
}
|
||||
|
|
|
@ -5,24 +5,21 @@ import (
|
|||
"context"
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
. "github.com/mudler/LocalAI/core/http"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/startup"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
@ -254,9 +251,6 @@ var _ = Describe("API test", func() {
|
|||
var cancel context.CancelFunc
|
||||
var tmpdir string
|
||||
var modelDir string
|
||||
var bcl *config.BackendConfigLoader
|
||||
var ml *model.ModelLoader
|
||||
var applicationConfig *config.ApplicationConfig
|
||||
|
||||
commonOpts := []config.AppOption{
|
||||
config.WithDebug(true),
|
||||
|
@ -302,7 +296,7 @@ var _ = Describe("API test", func() {
|
|||
},
|
||||
}
|
||||
|
||||
bcl, ml, applicationConfig, err = startup.Startup(
|
||||
application, err := application.New(
|
||||
append(commonOpts,
|
||||
config.WithContext(c),
|
||||
config.WithGalleries(galleries),
|
||||
|
@ -312,7 +306,7 @@ var _ = Describe("API test", func() {
|
|||
config.WithBackendAssetsOutput(backendAssetsDir))...)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
app, err = App(bcl, ml, applicationConfig)
|
||||
app, err = API(application)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
|
@ -541,7 +535,7 @@ var _ = Describe("API test", func() {
|
|||
var res map[string]string
|
||||
err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(res["location"]).To(Equal("San Francisco"), fmt.Sprint(res))
|
||||
Expect(res["location"]).To(ContainSubstring("San Francisco"), fmt.Sprint(res))
|
||||
Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res))
|
||||
Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason))
|
||||
|
||||
|
@ -643,7 +637,7 @@ var _ = Describe("API test", func() {
|
|||
},
|
||||
}
|
||||
|
||||
bcl, ml, applicationConfig, err = startup.Startup(
|
||||
application, err := application.New(
|
||||
append(commonOpts,
|
||||
config.WithContext(c),
|
||||
config.WithAudioDir(tmpdir),
|
||||
|
@ -654,7 +648,7 @@ var _ = Describe("API test", func() {
|
|||
config.WithBackendAssetsOutput(tmpdir))...,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
app, err = App(bcl, ml, applicationConfig)
|
||||
app, err = API(application)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
|
@ -710,7 +704,7 @@ var _ = Describe("API test", func() {
|
|||
Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp))
|
||||
|
||||
Expect(resp.StatusCode).To(Equal(200), fmt.Sprint(string(dat)))
|
||||
Expect(resp.Header.Get("Content-Type")).To(Equal("audio/x-wav"))
|
||||
Expect(resp.Header.Get("Content-Type")).To(Or(Equal("audio/x-wav"), Equal("audio/vnd.wave")))
|
||||
})
|
||||
It("installs and is capable to generate images", Label("stablediffusion"), func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
|
@ -774,14 +768,14 @@ var _ = Describe("API test", func() {
|
|||
|
||||
var err error
|
||||
|
||||
bcl, ml, applicationConfig, err = startup.Startup(
|
||||
application, err := application.New(
|
||||
append(commonOpts,
|
||||
config.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
|
||||
config.WithContext(c),
|
||||
config.WithModelPath(modelPath),
|
||||
)...)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
app, err = App(bcl, ml, applicationConfig)
|
||||
app, err = API(application)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
|
||||
|
@ -913,71 +907,6 @@ var _ = Describe("API test", func() {
|
|||
})
|
||||
})
|
||||
|
||||
Context("backends", func() {
|
||||
It("runs rwkv completion", func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
Skip("test supported only on linux")
|
||||
}
|
||||
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(resp.Choices) > 0).To(BeTrue())
|
||||
Expect(resp.Choices[0].Text).To(ContainSubstring("five"))
|
||||
|
||||
stream, err := client.CreateCompletionStream(context.TODO(), openai.CompletionRequest{
|
||||
Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,", Stream: true,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer stream.Close()
|
||||
|
||||
tokens := 0
|
||||
text := ""
|
||||
for {
|
||||
response, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
text += response.Choices[0].Text
|
||||
tokens++
|
||||
}
|
||||
Expect(text).ToNot(BeEmpty())
|
||||
Expect(text).To(ContainSubstring("five"))
|
||||
Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
|
||||
})
|
||||
It("runs rwkv chat completion", func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
Skip("test supported only on linux")
|
||||
}
|
||||
resp, err := client.CreateChatCompletion(context.TODO(),
|
||||
openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(resp.Choices) > 0).To(BeTrue())
|
||||
Expect(strings.ToLower(resp.Choices[0].Message.Content)).To(Or(ContainSubstring("sure"), ContainSubstring("five"), ContainSubstring("5")))
|
||||
|
||||
stream, err := client.CreateChatCompletionStream(context.TODO(), openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer stream.Close()
|
||||
|
||||
tokens := 0
|
||||
text := ""
|
||||
for {
|
||||
response, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
text += response.Choices[0].Delta.Content
|
||||
tokens++
|
||||
}
|
||||
Expect(text).ToNot(BeEmpty())
|
||||
Expect(strings.ToLower(text)).To(Or(ContainSubstring("sure"), ContainSubstring("five")))
|
||||
|
||||
Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
|
||||
})
|
||||
})
|
||||
|
||||
// See tests/integration/stores_test
|
||||
Context("Stores", Label("stores"), func() {
|
||||
|
||||
|
@ -1057,14 +986,14 @@ var _ = Describe("API test", func() {
|
|||
c, cancel = context.WithCancel(context.Background())
|
||||
|
||||
var err error
|
||||
bcl, ml, applicationConfig, err = startup.Startup(
|
||||
application, err := application.New(
|
||||
append(commonOpts,
|
||||
config.WithContext(c),
|
||||
config.WithModelPath(modelPath),
|
||||
config.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
app, err = App(bcl, ml, applicationConfig)
|
||||
app, err = API(application)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
|
|
|
@ -14,6 +14,8 @@ import (
|
|||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/templates"
|
||||
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/valyala/fasthttp"
|
||||
|
@ -24,7 +26,7 @@ import (
|
|||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/chat/completions [post]
|
||||
func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
var id, textContentToReturn string
|
||||
var created int
|
||||
|
||||
|
@ -294,148 +296,10 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
|||
// If we are using the tokenizer template, we don't need to process the messages
|
||||
// unless we are processing functions
|
||||
if !config.TemplateConfig.UseTokenizerTemplate || shouldUseFn {
|
||||
suppressConfigSystemPrompt := false
|
||||
mess := []string{}
|
||||
for messageIndex, i := range input.Messages {
|
||||
var content string
|
||||
role := i.Role
|
||||
|
||||
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
|
||||
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
|
||||
if (i.FunctionCall != nil || i.ToolCalls != nil) && i.Role == "assistant" {
|
||||
roleFn := "assistant_function_call"
|
||||
r := config.Roles[roleFn]
|
||||
if r != "" {
|
||||
role = roleFn
|
||||
}
|
||||
}
|
||||
r := config.Roles[role]
|
||||
contentExists := i.Content != nil && i.StringContent != ""
|
||||
|
||||
fcall := i.FunctionCall
|
||||
if len(i.ToolCalls) > 0 {
|
||||
fcall = i.ToolCalls
|
||||
}
|
||||
|
||||
// First attempt to populate content via a chat message specific template
|
||||
if config.TemplateConfig.ChatMessage != "" {
|
||||
chatMessageData := model.ChatMessageTemplateData{
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Role: r,
|
||||
RoleName: role,
|
||||
Content: i.StringContent,
|
||||
FunctionCall: fcall,
|
||||
FunctionName: i.Name,
|
||||
LastMessage: messageIndex == (len(input.Messages) - 1),
|
||||
Function: config.Grammar != "" && (messageIndex == (len(input.Messages) - 1)),
|
||||
MessageIndex: messageIndex,
|
||||
}
|
||||
templatedChatMessage, err := ml.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Interface("message", chatMessageData).Str("template", config.TemplateConfig.ChatMessage).Msg("error processing message with template, skipping")
|
||||
} else {
|
||||
if templatedChatMessage == "" {
|
||||
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
|
||||
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
|
||||
}
|
||||
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
|
||||
content = templatedChatMessage
|
||||
}
|
||||
}
|
||||
|
||||
marshalAnyRole := func(f any) {
|
||||
j, err := json.Marshal(f)
|
||||
if err == nil {
|
||||
if contentExists {
|
||||
content += "\n" + fmt.Sprint(r, " ", string(j))
|
||||
} else {
|
||||
content = fmt.Sprint(r, " ", string(j))
|
||||
}
|
||||
}
|
||||
}
|
||||
marshalAny := func(f any) {
|
||||
j, err := json.Marshal(f)
|
||||
if err == nil {
|
||||
if contentExists {
|
||||
content += "\n" + string(j)
|
||||
} else {
|
||||
content = string(j)
|
||||
}
|
||||
}
|
||||
}
|
||||
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
|
||||
if content == "" {
|
||||
if r != "" {
|
||||
if contentExists {
|
||||
content = fmt.Sprint(r, i.StringContent)
|
||||
}
|
||||
|
||||
if i.FunctionCall != nil {
|
||||
marshalAnyRole(i.FunctionCall)
|
||||
}
|
||||
if i.ToolCalls != nil {
|
||||
marshalAnyRole(i.ToolCalls)
|
||||
}
|
||||
} else {
|
||||
if contentExists {
|
||||
content = fmt.Sprint(i.StringContent)
|
||||
}
|
||||
if i.FunctionCall != nil {
|
||||
marshalAny(i.FunctionCall)
|
||||
}
|
||||
if i.ToolCalls != nil {
|
||||
marshalAny(i.ToolCalls)
|
||||
}
|
||||
}
|
||||
// Special Handling: System. We care if it was printed at all, not the r branch, so check seperately
|
||||
if contentExists && role == "system" {
|
||||
suppressConfigSystemPrompt = true
|
||||
}
|
||||
}
|
||||
|
||||
mess = append(mess, content)
|
||||
}
|
||||
|
||||
joinCharacter := "\n"
|
||||
if config.TemplateConfig.JoinChatMessagesByCharacter != nil {
|
||||
joinCharacter = *config.TemplateConfig.JoinChatMessagesByCharacter
|
||||
}
|
||||
|
||||
predInput = strings.Join(mess, joinCharacter)
|
||||
log.Debug().Msgf("Prompt (before templating): %s", predInput)
|
||||
|
||||
templateFile := ""
|
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
||||
templateFile = config.Model
|
||||
}
|
||||
|
||||
if config.TemplateConfig.Chat != "" && !shouldUseFn {
|
||||
templateFile = config.TemplateConfig.Chat
|
||||
}
|
||||
|
||||
if config.TemplateConfig.Functions != "" && shouldUseFn {
|
||||
templateFile = config.TemplateConfig.Functions
|
||||
}
|
||||
|
||||
if templateFile != "" {
|
||||
templatedInput, err := ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
SuppressSystemPrompt: suppressConfigSystemPrompt,
|
||||
Input: predInput,
|
||||
Functions: funcs,
|
||||
})
|
||||
if err == nil {
|
||||
predInput = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
||||
} else {
|
||||
log.Debug().Msgf("Template failed loading: %s", err.Error())
|
||||
}
|
||||
}
|
||||
predInput = evaluator.TemplateMessages(input.Messages, config, funcs, shouldUseFn)
|
||||
|
||||
log.Debug().Msgf("Prompt (after templating): %s", predInput)
|
||||
if shouldUseFn && config.Grammar != "" {
|
||||
if config.Grammar != "" {
|
||||
log.Debug().Msgf("Grammar: %+v", config.Grammar)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@ import (
|
|||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/templates"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
@ -25,7 +26,7 @@ import (
|
|||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/completions [post]
|
||||
func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
id := uuid.New().String()
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
|
@ -94,17 +95,6 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
|
|||
c.Set("Transfer-Encoding", "chunked")
|
||||
}
|
||||
|
||||
templateFile := ""
|
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
||||
templateFile = config.Model
|
||||
}
|
||||
|
||||
if config.TemplateConfig.Completion != "" {
|
||||
templateFile = config.TemplateConfig.Completion
|
||||
}
|
||||
|
||||
if input.Stream {
|
||||
if len(config.PromptStrings) > 1 {
|
||||
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
|
||||
|
@ -112,15 +102,13 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
|
|||
|
||||
predInput := config.PromptStrings[0]
|
||||
|
||||
if templateFile != "" {
|
||||
templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
|
||||
Input: predInput,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
})
|
||||
if err == nil {
|
||||
predInput = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
||||
}
|
||||
templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{
|
||||
Input: predInput,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
})
|
||||
if err == nil {
|
||||
predInput = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
||||
}
|
||||
|
||||
responses := make(chan schema.OpenAIResponse)
|
||||
|
@ -165,16 +153,13 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
|
|||
totalTokenUsage := backend.TokenUsage{}
|
||||
|
||||
for k, i := range config.PromptStrings {
|
||||
if templateFile != "" {
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Input: i,
|
||||
})
|
||||
if err == nil {
|
||||
i = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", i)
|
||||
}
|
||||
templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Input: i,
|
||||
})
|
||||
if err == nil {
|
||||
i = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", i)
|
||||
}
|
||||
|
||||
r, tokenUsage, err := ComputeChoices(
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/templates"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
@ -21,7 +22,8 @@ import (
|
|||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/edits [post]
|
||||
func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
|
||||
if err != nil {
|
||||
|
@ -35,31 +37,18 @@ func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConf
|
|||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
|
||||
templateFile := ""
|
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
||||
templateFile = config.Model
|
||||
}
|
||||
|
||||
if config.TemplateConfig.Edit != "" {
|
||||
templateFile = config.TemplateConfig.Edit
|
||||
}
|
||||
|
||||
var result []schema.Choice
|
||||
totalTokenUsage := backend.TokenUsage{}
|
||||
|
||||
for _, i := range config.InputStrings {
|
||||
if templateFile != "" {
|
||||
templatedInput, err := ml.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
|
||||
Input: i,
|
||||
Instruction: input.Instruction,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
})
|
||||
if err == nil {
|
||||
i = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", i)
|
||||
}
|
||||
templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.EditPromptTemplate, *config, templates.PromptTemplateData{
|
||||
Input: i,
|
||||
Instruction: input.Instruction,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
})
|
||||
if err == nil {
|
||||
i = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", i)
|
||||
}
|
||||
|
||||
r, tokenUsage, err := ComputeChoices(input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) {
|
||||
|
|
|
@ -11,62 +11,62 @@ import (
|
|||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func RegisterLocalAIRoutes(app *fiber.App,
|
||||
func RegisterLocalAIRoutes(router *fiber.App,
|
||||
cl *config.BackendConfigLoader,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
galleryService *services.GalleryService) {
|
||||
|
||||
app.Get("/swagger/*", swagger.HandlerDefault) // default
|
||||
router.Get("/swagger/*", swagger.HandlerDefault) // default
|
||||
|
||||
// LocalAI API endpoints
|
||||
if !appConfig.DisableGalleryEndpoint {
|
||||
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService)
|
||||
app.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
|
||||
app.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())
|
||||
router.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
|
||||
router.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())
|
||||
|
||||
app.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint())
|
||||
app.Get("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint())
|
||||
app.Post("/models/galleries", modelGalleryEndpointService.AddModelGalleryEndpoint())
|
||||
app.Delete("/models/galleries", modelGalleryEndpointService.RemoveModelGalleryEndpoint())
|
||||
app.Get("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint())
|
||||
app.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
|
||||
router.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint())
|
||||
router.Get("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint())
|
||||
router.Post("/models/galleries", modelGalleryEndpointService.AddModelGalleryEndpoint())
|
||||
router.Delete("/models/galleries", modelGalleryEndpointService.RemoveModelGalleryEndpoint())
|
||||
router.Get("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint())
|
||||
router.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
|
||||
}
|
||||
|
||||
app.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig))
|
||||
app.Post("/vad", localai.VADEndpoint(cl, ml, appConfig))
|
||||
router.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig))
|
||||
router.Post("/vad", localai.VADEndpoint(cl, ml, appConfig))
|
||||
|
||||
// Stores
|
||||
sl := model.NewModelLoader("")
|
||||
app.Post("/stores/set", localai.StoresSetEndpoint(sl, appConfig))
|
||||
app.Post("/stores/delete", localai.StoresDeleteEndpoint(sl, appConfig))
|
||||
app.Post("/stores/get", localai.StoresGetEndpoint(sl, appConfig))
|
||||
app.Post("/stores/find", localai.StoresFindEndpoint(sl, appConfig))
|
||||
router.Post("/stores/set", localai.StoresSetEndpoint(sl, appConfig))
|
||||
router.Post("/stores/delete", localai.StoresDeleteEndpoint(sl, appConfig))
|
||||
router.Post("/stores/get", localai.StoresGetEndpoint(sl, appConfig))
|
||||
router.Post("/stores/find", localai.StoresFindEndpoint(sl, appConfig))
|
||||
|
||||
if !appConfig.DisableMetrics {
|
||||
app.Get("/metrics", localai.LocalAIMetricsEndpoint())
|
||||
router.Get("/metrics", localai.LocalAIMetricsEndpoint())
|
||||
}
|
||||
|
||||
// Experimental Backend Statistics Module
|
||||
backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now
|
||||
app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
|
||||
app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
|
||||
router.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
|
||||
router.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
|
||||
|
||||
// p2p
|
||||
if p2p.IsP2PEnabled() {
|
||||
app.Get("/api/p2p", localai.ShowP2PNodes(appConfig))
|
||||
app.Get("/api/p2p/token", localai.ShowP2PToken(appConfig))
|
||||
router.Get("/api/p2p", localai.ShowP2PNodes(appConfig))
|
||||
router.Get("/api/p2p/token", localai.ShowP2PToken(appConfig))
|
||||
}
|
||||
|
||||
app.Get("/version", func(c *fiber.Ctx) error {
|
||||
router.Get("/version", func(c *fiber.Ctx) error {
|
||||
return c.JSON(struct {
|
||||
Version string `json:"version"`
|
||||
}{Version: internal.PrintableVersion()})
|
||||
})
|
||||
|
||||
app.Get("/system", localai.SystemInformations(ml, appConfig))
|
||||
router.Get("/system", localai.SystemInformations(ml, appConfig))
|
||||
|
||||
// misc
|
||||
app.Post("/v1/tokenize", localai.TokenizeEndpoint(cl, ml, appConfig))
|
||||
router.Post("/v1/tokenize", localai.TokenizeEndpoint(cl, ml, appConfig))
|
||||
|
||||
}
|
||||
|
|
|
@ -2,84 +2,134 @@ package routes
|
|||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/openai"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func RegisterOpenAIRoutes(app *fiber.App,
|
||||
cl *config.BackendConfigLoader,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig) {
|
||||
application *application.Application) {
|
||||
// openAI compatible API endpoint
|
||||
|
||||
// chat
|
||||
app.Post("/v1/chat/completions", openai.ChatEndpoint(cl, ml, appConfig))
|
||||
app.Post("/chat/completions", openai.ChatEndpoint(cl, ml, appConfig))
|
||||
app.Post("/v1/chat/completions",
|
||||
openai.ChatEndpoint(
|
||||
application.BackendLoader(),
|
||||
application.ModelLoader(),
|
||||
application.TemplatesEvaluator(),
|
||||
application.ApplicationConfig(),
|
||||
),
|
||||
)
|
||||
|
||||
app.Post("/chat/completions",
|
||||
openai.ChatEndpoint(
|
||||
application.BackendLoader(),
|
||||
application.ModelLoader(),
|
||||
application.TemplatesEvaluator(),
|
||||
application.ApplicationConfig(),
|
||||
),
|
||||
)
|
||||
|
||||
// edit
|
||||
app.Post("/v1/edits", openai.EditEndpoint(cl, ml, appConfig))
|
||||
app.Post("/edits", openai.EditEndpoint(cl, ml, appConfig))
|
||||
app.Post("/v1/edits",
|
||||
openai.EditEndpoint(
|
||||
application.BackendLoader(),
|
||||
application.ModelLoader(),
|
||||
application.TemplatesEvaluator(),
|
||||
application.ApplicationConfig(),
|
||||
),
|
||||
)
|
||||
|
||||
app.Post("/edits",
|
||||
openai.EditEndpoint(
|
||||
application.BackendLoader(),
|
||||
application.ModelLoader(),
|
||||
application.TemplatesEvaluator(),
|
||||
application.ApplicationConfig(),
|
||||
),
|
||||
)
|
||||
|
||||
// assistant
|
||||
app.Get("/v1/assistants", openai.ListAssistantsEndpoint(cl, ml, appConfig))
|
||||
app.Get("/assistants", openai.ListAssistantsEndpoint(cl, ml, appConfig))
|
||||
app.Post("/v1/assistants", openai.CreateAssistantEndpoint(cl, ml, appConfig))
|
||||
app.Post("/assistants", openai.CreateAssistantEndpoint(cl, ml, appConfig))
|
||||
app.Delete("/v1/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, ml, appConfig))
|
||||
app.Delete("/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, ml, appConfig))
|
||||
app.Get("/v1/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, ml, appConfig))
|
||||
app.Get("/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, ml, appConfig))
|
||||
app.Post("/v1/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, ml, appConfig))
|
||||
app.Post("/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, ml, appConfig))
|
||||
app.Get("/v1/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
|
||||
app.Get("/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
|
||||
app.Post("/v1/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
|
||||
app.Post("/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
|
||||
app.Delete("/v1/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
|
||||
app.Delete("/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
|
||||
app.Get("/v1/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, ml, appConfig))
|
||||
app.Get("/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, ml, appConfig))
|
||||
app.Get("/v1/assistants", openai.ListAssistantsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Get("/assistants", openai.ListAssistantsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Post("/v1/assistants", openai.CreateAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Post("/assistants", openai.CreateAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Delete("/v1/assistants/:assistant_id", openai.DeleteAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Delete("/assistants/:assistant_id", openai.DeleteAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Get("/v1/assistants/:assistant_id", openai.GetAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Get("/assistants/:assistant_id", openai.GetAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Post("/v1/assistants/:assistant_id", openai.ModifyAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Post("/assistants/:assistant_id", openai.ModifyAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Get("/v1/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Get("/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Post("/v1/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Post("/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Delete("/v1/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Delete("/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Get("/v1/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Get("/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
|
||||
// files
|
||||
app.Post("/v1/files", openai.UploadFilesEndpoint(cl, appConfig))
|
||||
app.Post("/files", openai.UploadFilesEndpoint(cl, appConfig))
|
||||
app.Get("/v1/files", openai.ListFilesEndpoint(cl, appConfig))
|
||||
app.Get("/files", openai.ListFilesEndpoint(cl, appConfig))
|
||||
app.Get("/v1/files/:file_id", openai.GetFilesEndpoint(cl, appConfig))
|
||||
app.Get("/files/:file_id", openai.GetFilesEndpoint(cl, appConfig))
|
||||
app.Delete("/v1/files/:file_id", openai.DeleteFilesEndpoint(cl, appConfig))
|
||||
app.Delete("/files/:file_id", openai.DeleteFilesEndpoint(cl, appConfig))
|
||||
app.Get("/v1/files/:file_id/content", openai.GetFilesContentsEndpoint(cl, appConfig))
|
||||
app.Get("/files/:file_id/content", openai.GetFilesContentsEndpoint(cl, appConfig))
|
||||
app.Post("/v1/files", openai.UploadFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||
app.Post("/files", openai.UploadFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||
app.Get("/v1/files", openai.ListFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||
app.Get("/files", openai.ListFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||
app.Get("/v1/files/:file_id", openai.GetFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||
app.Get("/files/:file_id", openai.GetFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||
app.Delete("/v1/files/:file_id", openai.DeleteFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||
app.Delete("/files/:file_id", openai.DeleteFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||
app.Get("/v1/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||
app.Get("/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||
|
||||
// completion
|
||||
app.Post("/v1/completions", openai.CompletionEndpoint(cl, ml, appConfig))
|
||||
app.Post("/completions", openai.CompletionEndpoint(cl, ml, appConfig))
|
||||
app.Post("/v1/engines/:model/completions", openai.CompletionEndpoint(cl, ml, appConfig))
|
||||
app.Post("/v1/completions",
|
||||
openai.CompletionEndpoint(
|
||||
application.BackendLoader(),
|
||||
application.ModelLoader(),
|
||||
application.TemplatesEvaluator(),
|
||||
application.ApplicationConfig(),
|
||||
),
|
||||
)
|
||||
|
||||
app.Post("/completions",
|
||||
openai.CompletionEndpoint(
|
||||
application.BackendLoader(),
|
||||
application.ModelLoader(),
|
||||
application.TemplatesEvaluator(),
|
||||
application.ApplicationConfig(),
|
||||
),
|
||||
)
|
||||
|
||||
app.Post("/v1/engines/:model/completions",
|
||||
openai.CompletionEndpoint(
|
||||
application.BackendLoader(),
|
||||
application.ModelLoader(),
|
||||
application.TemplatesEvaluator(),
|
||||
application.ApplicationConfig(),
|
||||
),
|
||||
)
|
||||
|
||||
// embeddings
|
||||
app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig))
|
||||
app.Post("/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig))
|
||||
app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig))
|
||||
app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Post("/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
|
||||
// audio
|
||||
app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(cl, ml, appConfig))
|
||||
app.Post("/v1/audio/speech", localai.TTSEndpoint(cl, ml, appConfig))
|
||||
app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
app.Post("/v1/audio/speech", localai.TTSEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
|
||||
// images
|
||||
app.Post("/v1/images/generations", openai.ImageEndpoint(cl, ml, appConfig))
|
||||
app.Post("/v1/images/generations", openai.ImageEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||
|
||||
if appConfig.ImageDir != "" {
|
||||
app.Static("/generated-images", appConfig.ImageDir)
|
||||
if application.ApplicationConfig().ImageDir != "" {
|
||||
app.Static("/generated-images", application.ApplicationConfig().ImageDir)
|
||||
}
|
||||
|
||||
if appConfig.AudioDir != "" {
|
||||
app.Static("/generated-audio", appConfig.AudioDir)
|
||||
if application.ApplicationConfig().AudioDir != "" {
|
||||
app.Static("/generated-audio", application.ApplicationConfig().AudioDir)
|
||||
}
|
||||
|
||||
// List models
|
||||
app.Get("/v1/models", openai.ListModelsEndpoint(cl, ml))
|
||||
app.Get("/models", openai.ListModelsEndpoint(cl, ml))
|
||||
app.Get("/v1/models", openai.ListModelsEndpoint(application.BackendLoader(), application.ModelLoader()))
|
||||
app.Get("/models", openai.ListModelsEndpoint(application.BackendLoader(), application.ModelLoader()))
|
||||
}
|
||||
|
|
|
@ -194,8 +194,9 @@ diffusers:
|
|||
pipeline_type: StableDiffusionPipeline
|
||||
enable_parameters: "negative_prompt,num_inference_steps,clip_skip"
|
||||
scheduler_type: "k_dpmpp_sde"
|
||||
cfg_scale: 8
|
||||
clip_skip: 11
|
||||
|
||||
cfg_scale: 8
|
||||
```
|
||||
|
||||
#### Configuration parameters
|
||||
|
@ -302,7 +303,8 @@ cuda: true
|
|||
diffusers:
|
||||
pipeline_type: StableDiffusionDepth2ImgPipeline
|
||||
enable_parameters: "negative_prompt,num_inference_steps,image"
|
||||
cfg_scale: 6
|
||||
|
||||
cfg_scale: 6
|
||||
```
|
||||
|
||||
```bash
|
||||
|
|
|
@ -10,13 +10,13 @@ ico = "rocket_launch"
|
|||
For installing LocalAI in Kubernetes, the deployment file from the `examples` can be used and customized as prefered:
|
||||
|
||||
```
|
||||
kubectl apply -f https://raw.githubusercontent.com/mudler/LocalAI/master/examples/kubernetes/deployment.yaml
|
||||
kubectl apply -f https://raw.githubusercontent.com/mudler/LocalAI-examples/refs/heads/main/kubernetes/deployment.yaml
|
||||
```
|
||||
|
||||
For Nvidia GPUs:
|
||||
|
||||
```
|
||||
kubectl apply -f https://raw.githubusercontent.com/mudler/LocalAI/master/examples/kubernetes/deployment-nvidia.yaml
|
||||
kubectl apply -f https://raw.githubusercontent.com/mudler/LocalAI-examples/refs/heads/main/kubernetes/deployment-nvidia.yaml
|
||||
```
|
||||
|
||||
Alternatively, the [helm chart](https://github.com/go-skynet/helm-charts) can be used as well:
|
||||
|
|
|
@ -6,7 +6,7 @@ weight = 24
|
|||
url = "/model-compatibility/"
|
||||
+++
|
||||
|
||||
Besides llama based models, LocalAI is compatible also with other architectures. The table below lists all the compatible models families and the associated binding repository.
|
||||
Besides llama based models, LocalAI is compatible also with other architectures. The table below lists all the backends, compatible models families and the associated repository.
|
||||
|
||||
{{% alert note %}}
|
||||
|
||||
|
@ -16,19 +16,8 @@ LocalAI will attempt to automatically load models which are not explicitly confi
|
|||
|
||||
| Backend and Bindings | Compatible models | Completion/Chat endpoint | Capability | Embeddings support | Token stream support | Acceleration |
|
||||
|----------------------------------------------------------------------------------|-----------------------|--------------------------|---------------------------|-----------------------------------|----------------------|--------------|
|
||||
| [llama.cpp]({{%relref "docs/features/text-generation#llama.cpp" %}}) | Vicuna, Alpaca, LLaMa, Falcon, Starcoder, GPT-2, [and many others](https://github.com/ggerganov/llama.cpp?tab=readme-ov-file#description) | yes | GPT and Functions | yes** | yes | CUDA, openCL, cuBLAS, Metal |
|
||||
| [gpt4all-llama](https://github.com/nomic-ai/gpt4all) | Vicuna, Alpaca, LLaMa | yes | GPT | no | yes | N/A |
|
||||
| [gpt4all-mpt](https://github.com/nomic-ai/gpt4all) | MPT | yes | GPT | no | yes | N/A |
|
||||
| [gpt4all-j](https://github.com/nomic-ai/gpt4all) | GPT4ALL-J | yes | GPT | no | yes | N/A |
|
||||
| [falcon-ggml](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-ggml-transformers.cpp)) | Falcon (*) | yes | GPT | no | no | N/A |
|
||||
| [dolly](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-ggml-transformers.cpp)) | Dolly | yes | GPT | no | no | N/A |
|
||||
| [gptj](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-ggml-transformers.cpp)) | GPTJ | yes | GPT | no | no | N/A |
|
||||
| [mpt](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-ggml-transformers.cpp)) | MPT | yes | GPT | no | no | N/A |
|
||||
| [replit](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-ggml-transformers.cpp)) | Replit | yes | GPT | no | no | N/A |
|
||||
| [gptneox](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-ggml-transformers.cpp)) | GPT NeoX, RedPajama, StableLM | yes | GPT | no | no | N/A |
|
||||
| [bloomz](https://github.com/NouamaneTazi/bloomz.cpp) ([binding](https://github.com/go-skynet/bloomz.cpp)) | Bloom | yes | GPT | no | no | N/A |
|
||||
| [rwkv](https://github.com/saharNooby/rwkv.cpp) ([binding](https://github.com/donomii/go-rwkv.cpp)) | rwkv | yes | GPT | no | yes | N/A |
|
||||
| [bert](https://github.com/skeskinen/bert.cpp) ([binding](https://github.com/go-skynet/go-bert.cpp)) | bert | no | Embeddings only | yes | no | N/A |
|
||||
| [llama.cpp]({{%relref "docs/features/text-generation#llama.cpp" %}}) | LLama, Mamba, RWKV, Falcon, Starcoder, GPT-2, [and many others](https://github.com/ggerganov/llama.cpp?tab=readme-ov-file#description) | yes | GPT and Functions | yes** | yes | CUDA, openCL, cuBLAS, Metal |
|
||||
| [llama.cpp's ggml model (backward compatibility with old format, before GGUF)](https://github.com/ggerganov/llama.cpp) ([binding](https://github.com/go-skynet/go-llama.cpp)) | LLama, GPT-2, [and many others](https://github.com/ggerganov/llama.cpp?tab=readme-ov-file#description) | yes | GPT and Functions | yes** | yes | CUDA, openCL, cuBLAS, Metal |
|
||||
| [whisper](https://github.com/ggerganov/whisper.cpp) | whisper | no | Audio | no | no | N/A |
|
||||
| [stablediffusion](https://github.com/EdVince/Stable-Diffusion-NCNN) ([binding](https://github.com/mudler/go-stable-diffusion)) | stablediffusion | no | Image | no | no | N/A |
|
||||
| [langchain-huggingface](https://github.com/tmc/langchaingo) | Any text generators available on HuggingFace through API | yes | GPT | no | no | N/A |
|
||||
|
@ -40,11 +29,18 @@ LocalAI will attempt to automatically load models which are not explicitly confi
|
|||
| `diffusers` | SD,... | no | Image generation | no | no | N/A |
|
||||
| `vall-e-x` | Vall-E | no | Audio generation and Voice cloning | no | no | CPU/CUDA |
|
||||
| `vllm` | Various GPTs and quantization formats | yes | GPT | no | no | CPU/CUDA |
|
||||
| `mamba` | Mamba models architecture | yes | GPT | no | no | CPU/CUDA |
|
||||
| `exllama2` | GPTQ | yes | GPT only | no | no | N/A |
|
||||
| `transformers-musicgen` | | no | Audio generation | no | no | N/A |
|
||||
| [tinydream](https://github.com/symisc/tiny-dream#tiny-dreaman-embedded-header-only-stable-diffusion-inference-c-librarypixlabiotiny-dream) | stablediffusion | no | Image | no | no | N/A |
|
||||
| `coqui` | Coqui | no | Audio generation and Voice cloning | no | no | CPU/CUDA |
|
||||
| `openvoice` | Open voice | no | Audio generation and Voice cloning | no | no | CPU/CUDA |
|
||||
| `parler-tts` | Open voice | no | Audio generation and Voice cloning | no | no | CPU/CUDA |
|
||||
| [rerankers](https://github.com/AnswerDotAI/rerankers) | Reranking API | no | Reranking | no | no | CPU/CUDA |
|
||||
| `transformers` | Various GPTs and quantization formats | yes | GPT, embeddings | yes | yes**** | CPU/CUDA/XPU |
|
||||
| [bark-cpp](https://github.com/PABannier/bark.cpp) | bark | no | Audio-Only | no | no | yes |
|
||||
| [stablediffusion-cpp](https://github.com/leejet/stable-diffusion.cpp) | stablediffusion-1, stablediffusion-2, stablediffusion-3, flux, PhotoMaker | no | Image | no | no | N/A |
|
||||
| [silero-vad](https://github.com/snakers4/silero-vad) with [Golang bindings](https://github.com/streamer45/silero-vad-go) | Silero VAD | no | Voice Activity Detection | no | no | CPU |
|
||||
|
||||
Note: any backend name listed above can be used in the `backend` field of the model configuration file (See [the advanced section]({{%relref "docs/advanced" %}})).
|
||||
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
{
|
||||
"version": "v2.23.0"
|
||||
"version": "v2.24.2"
|
||||
}
|
||||
|
|
2
docs/themes/hugo-theme-relearn
vendored
2
docs/themes/hugo-theme-relearn
vendored
|
@ -1 +1 @@
|
|||
Subproject commit 28fce6b04c414523280c53ee02f9f3a94d9d23da
|
||||
Subproject commit bd1f3d3432632c61bb12e7ec0f7673fed0289f19
|
12
gallery/flux-ggml.yaml
Normal file
12
gallery/flux-ggml.yaml
Normal file
|
@ -0,0 +1,12 @@
|
|||
---
|
||||
name: "flux-ggml"
|
||||
|
||||
config_file: |
|
||||
backend: stablediffusion-ggml
|
||||
step: 25
|
||||
options:
|
||||
- "diffusion_model"
|
||||
- "clip_l_path:clip_l.safetensors"
|
||||
- "t5xxl_path:t5xxl_fp16.safetensors"
|
||||
- "vae_path:ae.safetensors"
|
||||
- "sampler:euler"
|
|
@ -11,4 +11,5 @@ config_file: |
|
|||
cuda: true
|
||||
enable_parameters: num_inference_steps
|
||||
pipeline_type: FluxPipeline
|
||||
cfg_scale: 0
|
||||
|
||||
cfg_scale: 0
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -16,6 +16,7 @@ config_file: |
|
|||
|
||||
stopwords:
|
||||
- 'Assistant:'
|
||||
- '<s>'
|
||||
|
||||
template:
|
||||
chat: "{{.Input}}\nAssistant: "
|
||||
|
|
5
go.mod
5
go.mod
|
@ -76,6 +76,7 @@ require (
|
|||
cloud.google.com/go/auth v0.4.1 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.3.0 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/envoyproxy/protoc-gen-validate v1.0.4 // indirect
|
||||
github.com/fasthttp/websocket v1.5.3 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
|
@ -84,8 +85,12 @@ require (
|
|||
github.com/google/s2a-go v0.1.7 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.12.4 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/nikolalohinski/gonja/v2 v2.3.2 // indirect
|
||||
github.com/pion/datachannel v1.5.8 // indirect
|
||||
github.com/pion/dtls/v2 v2.2.12 // indirect
|
||||
github.com/pion/ice/v2 v2.3.34 // indirect
|
||||
|
|
12
go.sum
12
go.sum
|
@ -140,6 +140,8 @@ github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 h1:iFaUwBSo5Svw6L
|
|||
github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5/go.mod h1:qssHWj60/X5sZFNxpG4HBPDHVqxNm4DfnCKgrbZOT+s=
|
||||
github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY=
|
||||
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/elastic/gosigar v0.12.0/go.mod h1:iXRIGg2tLnu7LBdpqzyQfGDEidKCfWcCMS0WKyPWoMs=
|
||||
github.com/elastic/gosigar v0.14.3 h1:xwkKwPia+hSfg9GqrCUKYdId102m9qTJIIr7egmK/uo=
|
||||
github.com/elastic/gosigar v0.14.3/go.mod h1:iXRIGg2tLnu7LBdpqzyQfGDEidKCfWcCMS0WKyPWoMs=
|
||||
|
@ -268,6 +270,7 @@ github.com/google/go-containerregistry v0.19.2 h1:TannFKE1QSajsP6hPWb5oJNgKe1IKj
|
|||
github.com/google/go-containerregistry v0.19.2/go.mod h1:YCMFNQeeXeLF+dnhhWkqDItx/JSkH01j1Kis4PsjzFI=
|
||||
github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ=
|
||||
github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
|
||||
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
|
||||
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
|
||||
|
@ -353,6 +356,8 @@ github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwA
|
|||
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
|
||||
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
||||
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
|
||||
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
|
||||
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
|
||||
|
@ -474,8 +479,12 @@ github.com/moby/sys/sequential v0.5.0 h1:OPvI35Lzn9K04PBbCLW0g4LcFAJgHsvXsRyewg5
|
|||
github.com/moby/sys/sequential v0.5.0/go.mod h1:tH2cOOs5V9MlPiXcQzRC+eEyab644PWKGRYaaV5ZZlo=
|
||||
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
|
||||
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/mr-tron/base58 v1.1.2/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc=
|
||||
github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o=
|
||||
github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc=
|
||||
|
@ -519,6 +528,9 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
|
|||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo=
|
||||
github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
|
||||
github.com/nikolalohinski/gonja v1.5.3 h1:GsA+EEaZDZPGJ8JtpeGN78jidhOlxeJROpqMT9fTj9c=
|
||||
github.com/nikolalohinski/gonja/v2 v2.3.2 h1:UgLFfqi7L9XfX0PEcE4eUpvGojVQL5KhBfJJaBp7ZxY=
|
||||
github.com/nikolalohinski/gonja/v2 v2.3.2/go.mod h1:1Wcc/5huTu6y36e0sOFR1XQoFlylw3c3H3L5WOz0RDg=
|
||||
github.com/nwaples/rardecode v1.1.0 h1:vSxaY8vQhOcVr4mm5e8XllHWTiM4JF507A0Katqw7MQ=
|
||||
github.com/nwaples/rardecode v1.1.0/go.mod h1:5DzqNKiOdpKKBH87u8VlvAnPZMXcGRhxWkRpHbbfGS0=
|
||||
github.com/nxadm/tail v1.4.11 h1:8feyoE3OzPrcshW5/MJ4sGESc5cqmGkGCWlco4l0bqY=
|
||||
|
|
|
@ -9,8 +9,6 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/templates"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
|
@ -23,7 +21,6 @@ type ModelLoader struct {
|
|||
ModelPath string
|
||||
mu sync.Mutex
|
||||
models map[string]*Model
|
||||
templates *templates.TemplateCache
|
||||
wd *WatchDog
|
||||
}
|
||||
|
||||
|
@ -31,7 +28,6 @@ func NewModelLoader(modelPath string) *ModelLoader {
|
|||
nml := &ModelLoader{
|
||||
ModelPath: modelPath,
|
||||
models: make(map[string]*Model),
|
||||
templates: templates.NewTemplateCache(modelPath),
|
||||
}
|
||||
|
||||
return nml
|
||||
|
|
|
@ -1,52 +0,0 @@
|
|||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/templates"
|
||||
)
|
||||
|
||||
// Rather than pass an interface{} to the prompt template:
|
||||
// These are the definitions of all possible variables LocalAI will currently populate for use in a prompt template file
|
||||
// Please note: Not all of these are populated on every endpoint - your template should either be tested for each endpoint you map it to, or tolerant of zero values.
|
||||
type PromptTemplateData struct {
|
||||
SystemPrompt string
|
||||
SuppressSystemPrompt bool // used by chat specifically to indicate that SystemPrompt above should be _ignored_
|
||||
Input string
|
||||
Instruction string
|
||||
Functions []functions.Function
|
||||
MessageIndex int
|
||||
}
|
||||
|
||||
type ChatMessageTemplateData struct {
|
||||
SystemPrompt string
|
||||
Role string
|
||||
RoleName string
|
||||
FunctionName string
|
||||
Content string
|
||||
MessageIndex int
|
||||
Function bool
|
||||
FunctionCall interface{}
|
||||
LastMessage bool
|
||||
}
|
||||
|
||||
const (
|
||||
ChatPromptTemplate templates.TemplateType = iota
|
||||
ChatMessageTemplate
|
||||
CompletionPromptTemplate
|
||||
EditPromptTemplate
|
||||
FunctionsPromptTemplate
|
||||
)
|
||||
|
||||
func (ml *ModelLoader) EvaluateTemplateForPrompt(templateType templates.TemplateType, templateName string, in PromptTemplateData) (string, error) {
|
||||
// TODO: should this check be improved?
|
||||
if templateType == ChatMessageTemplate {
|
||||
return "", fmt.Errorf("invalid templateType: ChatMessage")
|
||||
}
|
||||
return ml.templates.EvaluateTemplate(templateType, templateName, in)
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) EvaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) {
|
||||
return ml.templates.EvaluateTemplate(ChatMessageTemplate, templateName, messageData)
|
||||
}
|
|
@ -1,197 +0,0 @@
|
|||
package model_test
|
||||
|
||||
import (
|
||||
. "github.com/mudler/LocalAI/pkg/model"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
const chatML = `<|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}
|
||||
{{- if .FunctionCall }}
|
||||
<tool_call>
|
||||
{{- else if eq .RoleName "tool" }}
|
||||
<tool_response>
|
||||
{{- end }}
|
||||
{{- if .Content}}
|
||||
{{.Content }}
|
||||
{{- end }}
|
||||
{{- if .FunctionCall}}
|
||||
{{toJson .FunctionCall}}
|
||||
{{- end }}
|
||||
{{- if .FunctionCall }}
|
||||
</tool_call>
|
||||
{{- else if eq .RoleName "tool" }}
|
||||
</tool_response>
|
||||
{{- end }}<|im_end|>`
|
||||
|
||||
const llama3 = `<|start_header_id|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}<|end_header_id|>
|
||||
|
||||
{{ if .FunctionCall -}}
|
||||
Function call:
|
||||
{{ else if eq .RoleName "tool" -}}
|
||||
Function response:
|
||||
{{ end -}}
|
||||
{{ if .Content -}}
|
||||
{{.Content -}}
|
||||
{{ else if .FunctionCall -}}
|
||||
{{ toJson .FunctionCall -}}
|
||||
{{ end -}}
|
||||
<|eot_id|>`
|
||||
|
||||
var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
|
||||
"user": {
|
||||
"template": llama3,
|
||||
"expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
|
||||
"data": ChatMessageTemplateData{
|
||||
SystemPrompt: "",
|
||||
Role: "user",
|
||||
RoleName: "user",
|
||||
Content: "A long time ago in a galaxy far, far away...",
|
||||
FunctionCall: nil,
|
||||
FunctionName: "",
|
||||
LastMessage: false,
|
||||
Function: false,
|
||||
MessageIndex: 0,
|
||||
},
|
||||
},
|
||||
"assistant": {
|
||||
"template": llama3,
|
||||
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
|
||||
"data": ChatMessageTemplateData{
|
||||
SystemPrompt: "",
|
||||
Role: "assistant",
|
||||
RoleName: "assistant",
|
||||
Content: "A long time ago in a galaxy far, far away...",
|
||||
FunctionCall: nil,
|
||||
FunctionName: "",
|
||||
LastMessage: false,
|
||||
Function: false,
|
||||
MessageIndex: 0,
|
||||
},
|
||||
},
|
||||
"function_call": {
|
||||
"template": llama3,
|
||||
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>",
|
||||
"data": ChatMessageTemplateData{
|
||||
SystemPrompt: "",
|
||||
Role: "assistant",
|
||||
RoleName: "assistant",
|
||||
Content: "",
|
||||
FunctionCall: map[string]string{"function": "test"},
|
||||
FunctionName: "",
|
||||
LastMessage: false,
|
||||
Function: false,
|
||||
MessageIndex: 0,
|
||||
},
|
||||
},
|
||||
"function_response": {
|
||||
"template": llama3,
|
||||
"expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>",
|
||||
"data": ChatMessageTemplateData{
|
||||
SystemPrompt: "",
|
||||
Role: "tool",
|
||||
RoleName: "tool",
|
||||
Content: "Response from tool",
|
||||
FunctionCall: nil,
|
||||
FunctionName: "",
|
||||
LastMessage: false,
|
||||
Function: false,
|
||||
MessageIndex: 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
|
||||
"user": {
|
||||
"template": chatML,
|
||||
"expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...<|im_end|>",
|
||||
"data": ChatMessageTemplateData{
|
||||
SystemPrompt: "",
|
||||
Role: "user",
|
||||
RoleName: "user",
|
||||
Content: "A long time ago in a galaxy far, far away...",
|
||||
FunctionCall: nil,
|
||||
FunctionName: "",
|
||||
LastMessage: false,
|
||||
Function: false,
|
||||
MessageIndex: 0,
|
||||
},
|
||||
},
|
||||
"assistant": {
|
||||
"template": chatML,
|
||||
"expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...<|im_end|>",
|
||||
"data": ChatMessageTemplateData{
|
||||
SystemPrompt: "",
|
||||
Role: "assistant",
|
||||
RoleName: "assistant",
|
||||
Content: "A long time ago in a galaxy far, far away...",
|
||||
FunctionCall: nil,
|
||||
FunctionName: "",
|
||||
LastMessage: false,
|
||||
Function: false,
|
||||
MessageIndex: 0,
|
||||
},
|
||||
},
|
||||
"function_call": {
|
||||
"template": chatML,
|
||||
"expected": "<|im_start|>assistant\n<tool_call>\n{\"function\":\"test\"}\n</tool_call><|im_end|>",
|
||||
"data": ChatMessageTemplateData{
|
||||
SystemPrompt: "",
|
||||
Role: "assistant",
|
||||
RoleName: "assistant",
|
||||
Content: "",
|
||||
FunctionCall: map[string]string{"function": "test"},
|
||||
FunctionName: "",
|
||||
LastMessage: false,
|
||||
Function: false,
|
||||
MessageIndex: 0,
|
||||
},
|
||||
},
|
||||
"function_response": {
|
||||
"template": chatML,
|
||||
"expected": "<|im_start|>tool\n<tool_response>\nResponse from tool\n</tool_response><|im_end|>",
|
||||
"data": ChatMessageTemplateData{
|
||||
SystemPrompt: "",
|
||||
Role: "tool",
|
||||
RoleName: "tool",
|
||||
Content: "Response from tool",
|
||||
FunctionCall: nil,
|
||||
FunctionName: "",
|
||||
LastMessage: false,
|
||||
Function: false,
|
||||
MessageIndex: 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var _ = Describe("Templates", func() {
|
||||
Context("chat message ChatML", func() {
|
||||
var modelLoader *ModelLoader
|
||||
BeforeEach(func() {
|
||||
modelLoader = NewModelLoader("")
|
||||
})
|
||||
for key := range chatMLTestMatch {
|
||||
foo := chatMLTestMatch[key]
|
||||
It("renders correctly `"+key+"`", func() {
|
||||
templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(templated).To(Equal(foo["expected"]), templated)
|
||||
})
|
||||
}
|
||||
})
|
||||
Context("chat message llama3", func() {
|
||||
var modelLoader *ModelLoader
|
||||
BeforeEach(func() {
|
||||
modelLoader = NewModelLoader("")
|
||||
})
|
||||
for key := range llama3TestMatch {
|
||||
foo := llama3TestMatch[key]
|
||||
It("renders correctly `"+key+"`", func() {
|
||||
templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(templated).To(Equal(foo["expected"]), templated)
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
|
@ -11,59 +11,41 @@ import (
|
|||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
|
||||
"github.com/Masterminds/sprig/v3"
|
||||
|
||||
"github.com/nikolalohinski/gonja/v2"
|
||||
"github.com/nikolalohinski/gonja/v2/exec"
|
||||
)
|
||||
|
||||
// Keep this in sync with config.TemplateConfig. Is there a more idiomatic way to accomplish this in go?
|
||||
// Technically, order doesn't _really_ matter, but the count must stay in sync, see tests/integration/reflect_test.go
|
||||
type TemplateType int
|
||||
|
||||
type TemplateCache struct {
|
||||
mu sync.Mutex
|
||||
templatesPath string
|
||||
templates map[TemplateType]map[string]*template.Template
|
||||
type templateCache struct {
|
||||
mu sync.Mutex
|
||||
templatesPath string
|
||||
templates map[TemplateType]map[string]*template.Template
|
||||
jinjaTemplates map[TemplateType]map[string]*exec.Template
|
||||
}
|
||||
|
||||
func NewTemplateCache(templatesPath string) *TemplateCache {
|
||||
tc := &TemplateCache{
|
||||
templatesPath: templatesPath,
|
||||
templates: make(map[TemplateType]map[string]*template.Template),
|
||||
func newTemplateCache(templatesPath string) *templateCache {
|
||||
tc := &templateCache{
|
||||
templatesPath: templatesPath,
|
||||
templates: make(map[TemplateType]map[string]*template.Template),
|
||||
jinjaTemplates: make(map[TemplateType]map[string]*exec.Template),
|
||||
}
|
||||
return tc
|
||||
}
|
||||
|
||||
func (tc *TemplateCache) initializeTemplateMapKey(tt TemplateType) {
|
||||
func (tc *templateCache) initializeTemplateMapKey(tt TemplateType) {
|
||||
if _, ok := tc.templates[tt]; !ok {
|
||||
tc.templates[tt] = make(map[string]*template.Template)
|
||||
}
|
||||
}
|
||||
|
||||
func (tc *TemplateCache) EvaluateTemplate(templateType TemplateType, templateName string, in interface{}) (string, error) {
|
||||
tc.mu.Lock()
|
||||
defer tc.mu.Unlock()
|
||||
|
||||
tc.initializeTemplateMapKey(templateType)
|
||||
m, ok := tc.templates[templateType][templateName]
|
||||
if !ok {
|
||||
// return "", fmt.Errorf("template not loaded: %s", templateName)
|
||||
loadErr := tc.loadTemplateIfExists(templateType, templateName)
|
||||
if loadErr != nil {
|
||||
return "", loadErr
|
||||
}
|
||||
m = tc.templates[templateType][templateName] // ok is not important since we check m on the next line, and wealready checked
|
||||
}
|
||||
if m == nil {
|
||||
return "", fmt.Errorf("failed loading a template for %s", templateName)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := m.Execute(&buf, in); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
func (tc *templateCache) existsInModelPath(s string) bool {
|
||||
return utils.ExistsInPath(tc.templatesPath, s)
|
||||
}
|
||||
|
||||
func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error {
|
||||
func (tc *templateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error {
|
||||
|
||||
// Check if the template was already loaded
|
||||
if _, ok := tc.templates[templateType][templateName]; ok {
|
||||
|
@ -82,6 +64,51 @@ func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templat
|
|||
return fmt.Errorf("template file outside path: %s", file)
|
||||
}
|
||||
|
||||
// can either be a file in the system or a string with the template
|
||||
if tc.existsInModelPath(modelTemplateFile) {
|
||||
d, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dat = string(d)
|
||||
} else {
|
||||
dat = templateName
|
||||
}
|
||||
|
||||
// Parse the template
|
||||
tmpl, err := template.New("prompt").Funcs(sprig.FuncMap()).Parse(dat)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tc.templates[templateType][templateName] = tmpl
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tc *templateCache) initializeJinjaTemplateMapKey(tt TemplateType) {
|
||||
if _, ok := tc.jinjaTemplates[tt]; !ok {
|
||||
tc.jinjaTemplates[tt] = make(map[string]*exec.Template)
|
||||
}
|
||||
}
|
||||
|
||||
func (tc *templateCache) loadJinjaTemplateIfExists(templateType TemplateType, templateName string) error {
|
||||
// Check if the template was already loaded
|
||||
if _, ok := tc.jinjaTemplates[templateType][templateName]; ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if the model path exists
|
||||
// skip any error here - we run anyway if a template does not exist
|
||||
modelTemplateFile := fmt.Sprintf("%s.tmpl", templateName)
|
||||
|
||||
dat := ""
|
||||
file := filepath.Join(tc.templatesPath, modelTemplateFile)
|
||||
|
||||
// Security check
|
||||
if err := utils.VerifyPath(modelTemplateFile, tc.templatesPath); err != nil {
|
||||
return fmt.Errorf("template file outside path: %s", file)
|
||||
}
|
||||
|
||||
// can either be a file in the system or a string with the template
|
||||
if utils.ExistsInPath(tc.templatesPath, modelTemplateFile) {
|
||||
d, err := os.ReadFile(file)
|
||||
|
@ -93,12 +120,65 @@ func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templat
|
|||
dat = templateName
|
||||
}
|
||||
|
||||
// Parse the template
|
||||
tmpl, err := template.New("prompt").Funcs(sprig.FuncMap()).Parse(dat)
|
||||
tmpl, err := gonja.FromString(dat)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tc.templates[templateType][templateName] = tmpl
|
||||
tc.jinjaTemplates[templateType][templateName] = tmpl
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tc *templateCache) evaluateJinjaTemplate(templateType TemplateType, templateNameOrContent string, in map[string]interface{}) (string, error) {
|
||||
tc.mu.Lock()
|
||||
defer tc.mu.Unlock()
|
||||
|
||||
tc.initializeJinjaTemplateMapKey(templateType)
|
||||
m, ok := tc.jinjaTemplates[templateType][templateNameOrContent]
|
||||
if !ok {
|
||||
// return "", fmt.Errorf("template not loaded: %s", templateName)
|
||||
loadErr := tc.loadJinjaTemplateIfExists(templateType, templateNameOrContent)
|
||||
if loadErr != nil {
|
||||
return "", loadErr
|
||||
}
|
||||
m = tc.jinjaTemplates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and wealready checked
|
||||
}
|
||||
if m == nil {
|
||||
return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
data := exec.NewContext(in)
|
||||
|
||||
if err := m.Execute(&buf, data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func (tc *templateCache) evaluateTemplate(templateType TemplateType, templateNameOrContent string, in interface{}) (string, error) {
|
||||
tc.mu.Lock()
|
||||
defer tc.mu.Unlock()
|
||||
|
||||
tc.initializeTemplateMapKey(templateType)
|
||||
m, ok := tc.templates[templateType][templateNameOrContent]
|
||||
if !ok {
|
||||
// return "", fmt.Errorf("template not loaded: %s", templateName)
|
||||
loadErr := tc.loadTemplateIfExists(templateType, templateNameOrContent)
|
||||
if loadErr != nil {
|
||||
return "", loadErr
|
||||
}
|
||||
m = tc.templates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and wealready checked
|
||||
}
|
||||
if m == nil {
|
||||
return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := m.Execute(&buf, in); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
|
|
@ -1,73 +0,0 @@
|
|||
package templates_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/templates" // Update with your module path
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("TemplateCache", func() {
|
||||
var (
|
||||
templateCache *templates.TemplateCache
|
||||
tempDir string
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
tempDir, err = os.MkdirTemp("", "templates")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// Writing example template files
|
||||
err = os.WriteFile(filepath.Join(tempDir, "example.tmpl"), []byte("Hello, {{.Name}}!"), 0600)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
err = os.WriteFile(filepath.Join(tempDir, "empty.tmpl"), []byte(""), 0600)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
templateCache = templates.NewTemplateCache(tempDir)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
os.RemoveAll(tempDir) // Clean up
|
||||
})
|
||||
|
||||
Describe("EvaluateTemplate", func() {
|
||||
Context("when template is loaded successfully", func() {
|
||||
It("should evaluate the template correctly", func() {
|
||||
result, err := templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(result).To(Equal("Hello, Gopher!"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when template isn't a file", func() {
|
||||
It("should parse from string", func() {
|
||||
result, err := templateCache.EvaluateTemplate(1, "{{.Name}}", map[string]string{"Name": "Gopher"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(result).To(Equal("Gopher"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when template is empty", func() {
|
||||
It("should return an empty string", func() {
|
||||
result, err := templateCache.EvaluateTemplate(1, "empty", nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(result).To(Equal(""))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("concurrency", func() {
|
||||
It("should handle multiple concurrent accesses", func(done Done) {
|
||||
go func() {
|
||||
_, _ = templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"})
|
||||
}()
|
||||
go func() {
|
||||
_, _ = templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"})
|
||||
}()
|
||||
close(done)
|
||||
}, 0.1) // timeout in seconds
|
||||
})
|
||||
})
|
295
pkg/templates/evaluator.go
Normal file
295
pkg/templates/evaluator.go
Normal file
|
@ -0,0 +1,295 @@
|
|||
package templates
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Rather than pass an interface{} to the prompt template:
|
||||
// These are the definitions of all possible variables LocalAI will currently populate for use in a prompt template file
|
||||
// Please note: Not all of these are populated on every endpoint - your template should either be tested for each endpoint you map it to, or tolerant of zero values.
|
||||
type PromptTemplateData struct {
|
||||
SystemPrompt string
|
||||
SuppressSystemPrompt bool // used by chat specifically to indicate that SystemPrompt above should be _ignored_
|
||||
Input string
|
||||
Instruction string
|
||||
Functions []functions.Function
|
||||
MessageIndex int
|
||||
}
|
||||
|
||||
type ChatMessageTemplateData struct {
|
||||
SystemPrompt string
|
||||
Role string
|
||||
RoleName string
|
||||
FunctionName string
|
||||
Content string
|
||||
MessageIndex int
|
||||
Function bool
|
||||
FunctionCall interface{}
|
||||
LastMessage bool
|
||||
}
|
||||
|
||||
const (
|
||||
ChatPromptTemplate TemplateType = iota
|
||||
ChatMessageTemplate
|
||||
CompletionPromptTemplate
|
||||
EditPromptTemplate
|
||||
FunctionsPromptTemplate
|
||||
)
|
||||
|
||||
type Evaluator struct {
|
||||
cache *templateCache
|
||||
}
|
||||
|
||||
func NewEvaluator(modelPath string) *Evaluator {
|
||||
return &Evaluator{
|
||||
cache: newTemplateCache(modelPath),
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config config.BackendConfig, in PromptTemplateData) (string, error) {
|
||||
template := ""
|
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
if e.cache.existsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
||||
template = config.Model
|
||||
}
|
||||
|
||||
switch templateType {
|
||||
case CompletionPromptTemplate:
|
||||
if config.TemplateConfig.Completion != "" {
|
||||
template = config.TemplateConfig.Completion
|
||||
}
|
||||
case EditPromptTemplate:
|
||||
if config.TemplateConfig.Edit != "" {
|
||||
template = config.TemplateConfig.Edit
|
||||
}
|
||||
case ChatPromptTemplate:
|
||||
if config.TemplateConfig.Chat != "" {
|
||||
template = config.TemplateConfig.Chat
|
||||
}
|
||||
case FunctionsPromptTemplate:
|
||||
if config.TemplateConfig.Functions != "" {
|
||||
template = config.TemplateConfig.Functions
|
||||
}
|
||||
}
|
||||
|
||||
if template == "" {
|
||||
return in.Input, nil
|
||||
}
|
||||
|
||||
if config.TemplateConfig.JinjaTemplate {
|
||||
return e.evaluateJinjaTemplateForPrompt(templateType, template, in)
|
||||
}
|
||||
|
||||
return e.cache.evaluateTemplate(templateType, template, in)
|
||||
}
|
||||
|
||||
func (e *Evaluator) evaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) {
|
||||
return e.cache.evaluateTemplate(ChatMessageTemplate, templateName, messageData)
|
||||
}
|
||||
|
||||
func (e *Evaluator) templateJinjaChat(templateName string, messageData []ChatMessageTemplateData, funcs []functions.Function) (string, error) {
|
||||
|
||||
conversation := make(map[string]interface{})
|
||||
messages := make([]map[string]interface{}, len(messageData))
|
||||
|
||||
// convert from ChatMessageTemplateData to what the jinja template expects
|
||||
|
||||
for _, message := range messageData {
|
||||
// TODO: this seems to cover minimum text templates. Can be expanded to cover more complex interactions
|
||||
var data []byte
|
||||
data, _ = json.Marshal(message.FunctionCall)
|
||||
messages = append(messages, map[string]interface{}{
|
||||
"role": message.RoleName,
|
||||
"content": message.Content,
|
||||
"tool_call": string(data),
|
||||
})
|
||||
}
|
||||
|
||||
conversation["messages"] = messages
|
||||
|
||||
// if tools are detected, add these
|
||||
if len(funcs) > 0 {
|
||||
conversation["tools"] = funcs
|
||||
}
|
||||
|
||||
return e.cache.evaluateJinjaTemplate(ChatMessageTemplate, templateName, conversation)
|
||||
}
|
||||
|
||||
func (e *Evaluator) evaluateJinjaTemplateForPrompt(templateType TemplateType, templateName string, in PromptTemplateData) (string, error) {
|
||||
|
||||
conversation := make(map[string]interface{})
|
||||
|
||||
conversation["system_prompt"] = in.SystemPrompt
|
||||
conversation["content"] = in.Input
|
||||
|
||||
return e.cache.evaluateJinjaTemplate(templateType, templateName, conversation)
|
||||
}
|
||||
|
||||
func (e *Evaluator) TemplateMessages(messages []schema.Message, config *config.BackendConfig, funcs []functions.Function, shouldUseFn bool) string {
|
||||
|
||||
if config.TemplateConfig.JinjaTemplate {
|
||||
var messageData []ChatMessageTemplateData
|
||||
for messageIndex, i := range messages {
|
||||
fcall := i.FunctionCall
|
||||
if len(i.ToolCalls) > 0 {
|
||||
fcall = i.ToolCalls
|
||||
}
|
||||
messageData = append(messageData, ChatMessageTemplateData{
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Role: config.Roles[i.Role],
|
||||
RoleName: i.Role,
|
||||
Content: i.StringContent,
|
||||
FunctionCall: fcall,
|
||||
FunctionName: i.Name,
|
||||
LastMessage: messageIndex == (len(messages) - 1),
|
||||
Function: config.Grammar != "" && (messageIndex == (len(messages) - 1)),
|
||||
MessageIndex: messageIndex,
|
||||
})
|
||||
}
|
||||
|
||||
templatedInput, err := e.templateJinjaChat(config.TemplateConfig.ChatMessage, messageData, funcs)
|
||||
if err == nil {
|
||||
return templatedInput
|
||||
}
|
||||
}
|
||||
|
||||
var predInput string
|
||||
suppressConfigSystemPrompt := false
|
||||
mess := []string{}
|
||||
for messageIndex, i := range messages {
|
||||
var content string
|
||||
role := i.Role
|
||||
|
||||
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
|
||||
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
|
||||
if (i.FunctionCall != nil || i.ToolCalls != nil) && i.Role == "assistant" {
|
||||
roleFn := "assistant_function_call"
|
||||
r := config.Roles[roleFn]
|
||||
if r != "" {
|
||||
role = roleFn
|
||||
}
|
||||
}
|
||||
r := config.Roles[role]
|
||||
contentExists := i.Content != nil && i.StringContent != ""
|
||||
|
||||
fcall := i.FunctionCall
|
||||
if len(i.ToolCalls) > 0 {
|
||||
fcall = i.ToolCalls
|
||||
}
|
||||
|
||||
// First attempt to populate content via a chat message specific template
|
||||
if config.TemplateConfig.ChatMessage != "" {
|
||||
chatMessageData := ChatMessageTemplateData{
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Role: r,
|
||||
RoleName: role,
|
||||
Content: i.StringContent,
|
||||
FunctionCall: fcall,
|
||||
FunctionName: i.Name,
|
||||
LastMessage: messageIndex == (len(messages) - 1),
|
||||
Function: config.Grammar != "" && (messageIndex == (len(messages) - 1)),
|
||||
MessageIndex: messageIndex,
|
||||
}
|
||||
templatedChatMessage, err := e.evaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Interface("message", chatMessageData).Str("template", config.TemplateConfig.ChatMessage).Msg("error processing message with template, skipping")
|
||||
} else {
|
||||
if templatedChatMessage == "" {
|
||||
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
|
||||
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
|
||||
}
|
||||
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
|
||||
content = templatedChatMessage
|
||||
}
|
||||
}
|
||||
|
||||
marshalAnyRole := func(f any) {
|
||||
j, err := json.Marshal(f)
|
||||
if err == nil {
|
||||
if contentExists {
|
||||
content += "\n" + fmt.Sprint(r, " ", string(j))
|
||||
} else {
|
||||
content = fmt.Sprint(r, " ", string(j))
|
||||
}
|
||||
}
|
||||
}
|
||||
marshalAny := func(f any) {
|
||||
j, err := json.Marshal(f)
|
||||
if err == nil {
|
||||
if contentExists {
|
||||
content += "\n" + string(j)
|
||||
} else {
|
||||
content = string(j)
|
||||
}
|
||||
}
|
||||
}
|
||||
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
|
||||
if content == "" {
|
||||
if r != "" {
|
||||
if contentExists {
|
||||
content = fmt.Sprint(r, i.StringContent)
|
||||
}
|
||||
|
||||
if i.FunctionCall != nil {
|
||||
marshalAnyRole(i.FunctionCall)
|
||||
}
|
||||
if i.ToolCalls != nil {
|
||||
marshalAnyRole(i.ToolCalls)
|
||||
}
|
||||
} else {
|
||||
if contentExists {
|
||||
content = fmt.Sprint(i.StringContent)
|
||||
}
|
||||
if i.FunctionCall != nil {
|
||||
marshalAny(i.FunctionCall)
|
||||
}
|
||||
if i.ToolCalls != nil {
|
||||
marshalAny(i.ToolCalls)
|
||||
}
|
||||
}
|
||||
// Special Handling: System. We care if it was printed at all, not the r branch, so check seperately
|
||||
if contentExists && role == "system" {
|
||||
suppressConfigSystemPrompt = true
|
||||
}
|
||||
}
|
||||
|
||||
mess = append(mess, content)
|
||||
}
|
||||
|
||||
joinCharacter := "\n"
|
||||
if config.TemplateConfig.JoinChatMessagesByCharacter != nil {
|
||||
joinCharacter = *config.TemplateConfig.JoinChatMessagesByCharacter
|
||||
}
|
||||
|
||||
predInput = strings.Join(mess, joinCharacter)
|
||||
log.Debug().Msgf("Prompt (before templating): %s", predInput)
|
||||
|
||||
promptTemplate := ChatPromptTemplate
|
||||
|
||||
if config.TemplateConfig.Functions != "" && shouldUseFn {
|
||||
promptTemplate = FunctionsPromptTemplate
|
||||
}
|
||||
|
||||
templatedInput, err := e.EvaluateTemplateForPrompt(promptTemplate, *config, PromptTemplateData{
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
SuppressSystemPrompt: suppressConfigSystemPrompt,
|
||||
Input: predInput,
|
||||
Functions: funcs,
|
||||
})
|
||||
if err == nil {
|
||||
predInput = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
||||
} else {
|
||||
log.Debug().Msgf("Template failed loading: %s", err.Error())
|
||||
}
|
||||
|
||||
return predInput
|
||||
}
|
253
pkg/templates/evaluator_test.go
Normal file
253
pkg/templates/evaluator_test.go
Normal file
|
@ -0,0 +1,253 @@
|
|||
package templates_test
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
. "github.com/mudler/LocalAI/pkg/templates"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
const toolCallJinja = `{{ '<|begin_of_text|>' }}{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ '<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
' + system_message + '<|eot_id|>' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|start_header_id|>user<|end_header_id|>
|
||||
|
||||
' + content + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
' }}{% elif message['role'] == 'assistant' %}{{ content + '<|eot_id|>' }}{% endif %}{% endfor %}`
|
||||
|
||||
const chatML = `<|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}
|
||||
{{- if .FunctionCall }}
|
||||
<tool_call>
|
||||
{{- else if eq .RoleName "tool" }}
|
||||
<tool_response>
|
||||
{{- end }}
|
||||
{{- if .Content}}
|
||||
{{.Content }}
|
||||
{{- end }}
|
||||
{{- if .FunctionCall}}
|
||||
{{toJson .FunctionCall}}
|
||||
{{- end }}
|
||||
{{- if .FunctionCall }}
|
||||
</tool_call>
|
||||
{{- else if eq .RoleName "tool" }}
|
||||
</tool_response>
|
||||
{{- end }}<|im_end|>`
|
||||
|
||||
const llama3 = `<|start_header_id|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}<|end_header_id|>
|
||||
|
||||
{{ if .FunctionCall -}}
|
||||
Function call:
|
||||
{{ else if eq .RoleName "tool" -}}
|
||||
Function response:
|
||||
{{ end -}}
|
||||
{{ if .Content -}}
|
||||
{{.Content -}}
|
||||
{{ else if .FunctionCall -}}
|
||||
{{ toJson .FunctionCall -}}
|
||||
{{ end -}}
|
||||
<|eot_id|>`
|
||||
|
||||
var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
|
||||
"user": {
|
||||
"expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
|
||||
"config": &config.BackendConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: llama3,
|
||||
},
|
||||
},
|
||||
"functions": []functions.Function{},
|
||||
"shouldUseFn": false,
|
||||
"messages": []schema.Message{
|
||||
{
|
||||
Role: "user",
|
||||
StringContent: "A long time ago in a galaxy far, far away...",
|
||||
},
|
||||
},
|
||||
},
|
||||
"assistant": {
|
||||
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
|
||||
"config": &config.BackendConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: llama3,
|
||||
},
|
||||
},
|
||||
"functions": []functions.Function{},
|
||||
"messages": []schema.Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
StringContent: "A long time ago in a galaxy far, far away...",
|
||||
},
|
||||
},
|
||||
"shouldUseFn": false,
|
||||
},
|
||||
"function_call": {
|
||||
|
||||
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>",
|
||||
"config": &config.BackendConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: llama3,
|
||||
},
|
||||
},
|
||||
"functions": []functions.Function{},
|
||||
"messages": []schema.Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
FunctionCall: map[string]string{"function": "test"},
|
||||
},
|
||||
},
|
||||
"shouldUseFn": false,
|
||||
},
|
||||
"function_response": {
|
||||
"expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>",
|
||||
"config": &config.BackendConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: llama3,
|
||||
},
|
||||
},
|
||||
"functions": []functions.Function{},
|
||||
"messages": []schema.Message{
|
||||
{
|
||||
Role: "tool",
|
||||
StringContent: "Response from tool",
|
||||
},
|
||||
},
|
||||
"shouldUseFn": false,
|
||||
},
|
||||
}
|
||||
|
||||
var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
|
||||
"user": {
|
||||
"expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...<|im_end|>",
|
||||
"config": &config.BackendConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: chatML,
|
||||
},
|
||||
},
|
||||
"functions": []functions.Function{},
|
||||
"shouldUseFn": false,
|
||||
"messages": []schema.Message{
|
||||
{
|
||||
Role: "user",
|
||||
StringContent: "A long time ago in a galaxy far, far away...",
|
||||
},
|
||||
},
|
||||
},
|
||||
"assistant": {
|
||||
"expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...<|im_end|>",
|
||||
"config": &config.BackendConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: chatML,
|
||||
},
|
||||
},
|
||||
"functions": []functions.Function{},
|
||||
"messages": []schema.Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
StringContent: "A long time ago in a galaxy far, far away...",
|
||||
},
|
||||
},
|
||||
"shouldUseFn": false,
|
||||
},
|
||||
"function_call": {
|
||||
"expected": "<|im_start|>assistant\n<tool_call>\n{\"function\":\"test\"}\n</tool_call><|im_end|>",
|
||||
"config": &config.BackendConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: chatML,
|
||||
},
|
||||
},
|
||||
"functions": []functions.Function{
|
||||
{
|
||||
Name: "test",
|
||||
Description: "test",
|
||||
Parameters: nil,
|
||||
},
|
||||
},
|
||||
"shouldUseFn": true,
|
||||
"messages": []schema.Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
FunctionCall: map[string]string{"function": "test"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"function_response": {
|
||||
"expected": "<|im_start|>tool\n<tool_response>\nResponse from tool\n</tool_response><|im_end|>",
|
||||
"config": &config.BackendConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: chatML,
|
||||
},
|
||||
},
|
||||
"functions": []functions.Function{},
|
||||
"shouldUseFn": false,
|
||||
"messages": []schema.Message{
|
||||
{
|
||||
Role: "tool",
|
||||
StringContent: "Response from tool",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var jinjaTest map[string]map[string]interface{} = map[string]map[string]interface{}{
|
||||
"user": {
|
||||
"expected": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
"config": &config.BackendConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: toolCallJinja,
|
||||
JinjaTemplate: true,
|
||||
},
|
||||
},
|
||||
"functions": []functions.Function{},
|
||||
"shouldUseFn": false,
|
||||
"messages": []schema.Message{
|
||||
{
|
||||
Role: "user",
|
||||
StringContent: "A long time ago in a galaxy far, far away...",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
var _ = Describe("Templates", func() {
|
||||
Context("chat message ChatML", func() {
|
||||
var evaluator *Evaluator
|
||||
BeforeEach(func() {
|
||||
evaluator = NewEvaluator("")
|
||||
})
|
||||
for key := range chatMLTestMatch {
|
||||
foo := chatMLTestMatch[key]
|
||||
It("renders correctly `"+key+"`", func() {
|
||||
templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
|
||||
Expect(templated).To(Equal(foo["expected"]), templated)
|
||||
})
|
||||
}
|
||||
})
|
||||
Context("chat message llama3", func() {
|
||||
var evaluator *Evaluator
|
||||
BeforeEach(func() {
|
||||
evaluator = NewEvaluator("")
|
||||
})
|
||||
for key := range llama3TestMatch {
|
||||
foo := llama3TestMatch[key]
|
||||
It("renders correctly `"+key+"`", func() {
|
||||
templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
|
||||
Expect(templated).To(Equal(foo["expected"]), templated)
|
||||
})
|
||||
}
|
||||
})
|
||||
Context("chat message jinja", func() {
|
||||
var evaluator *Evaluator
|
||||
BeforeEach(func() {
|
||||
evaluator = NewEvaluator("")
|
||||
})
|
||||
for key := range jinjaTest {
|
||||
foo := jinjaTest[key]
|
||||
It("renders correctly `"+key+"`", func() {
|
||||
templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
|
||||
Expect(templated).To(Equal(foo["expected"]), templated)
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
|
@ -14,6 +14,7 @@ roles:
|
|||
|
||||
stopwords:
|
||||
- 'Assistant:'
|
||||
- '<s>'
|
||||
|
||||
template:
|
||||
chat: |
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue