Merge branch 'master' into fix/stream_tokens_usage

This commit is contained in:
Ettore Di Giacinto 2024-12-17 09:25:56 +01:00 committed by GitHub
commit 9f6be2be12
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
79 changed files with 2716 additions and 941 deletions

5
.github/labeler.yml vendored
View file

@ -1,6 +1,11 @@
enhancements:
- head-branch: ['^feature', 'feature']
dependencies:
- any:
- changed-files:
- any-glob-to-any-file: 'Makefile'
kind/documentation:
- any:
- changed-files:

View file

@ -12,23 +12,14 @@ jobs:
- repository: "ggerganov/llama.cpp"
variable: "CPPLLAMA_VERSION"
branch: "master"
- repository: "go-skynet/go-ggml-transformers.cpp"
variable: "GOGGMLTRANSFORMERS_VERSION"
branch: "master"
- repository: "donomii/go-rwkv.cpp"
variable: "RWKV_VERSION"
branch: "main"
- repository: "ggerganov/whisper.cpp"
variable: "WHISPER_CPP_VERSION"
branch: "master"
- repository: "go-skynet/go-bert.cpp"
variable: "BERT_VERSION"
branch: "master"
- repository: "go-skynet/bloomz.cpp"
variable: "BLOOMZ_VERSION"
- repository: "PABannier/bark.cpp"
variable: "BARKCPP_VERSION"
branch: "main"
- repository: "mudler/go-ggllm.cpp"
variable: "GOGGLLM_VERSION"
- repository: "leejet/stable-diffusion.cpp"
variable: "STABLEDIFFUSION_GGML_VERSION"
branch: "master"
- repository: "mudler/go-stable-diffusion"
variable: "STABLEDIFFUSION_VERSION"

1
.gitignore vendored
View file

@ -2,6 +2,7 @@
/sources/
__pycache__/
*.a
*.o
get-sources
prepare-sources
/backend/cpp/llama/grpc-server

View file

@ -8,7 +8,7 @@ DETECT_LIBS?=true
# llama.cpp versions
GOLLAMA_REPO?=https://github.com/go-skynet/go-llama.cpp
GOLLAMA_VERSION?=2b57a8ae43e4699d3dc5d1496a1ccd42922993be
CPPLLAMA_VERSION?=3ad5451f3b75809e3033e4e577b9f60bcaf6676a
CPPLLAMA_VERSION?=08ea539df211e46bb4d0dd275e541cb591d5ebc8
# whisper.cpp version
WHISPER_REPO?=https://github.com/ggerganov/whisper.cpp
@ -26,6 +26,14 @@ STABLEDIFFUSION_VERSION?=4a3cd6aeae6f66ee57eae9a0075f8c58c3a6a38f
TINYDREAM_REPO?=https://github.com/M0Rf30/go-tiny-dream
TINYDREAM_VERSION?=c04fa463ace9d9a6464313aa5f9cd0f953b6c057
# bark.cpp
BARKCPP_REPO?=https://github.com/PABannier/bark.cpp.git
BARKCPP_VERSION?=v1.0.0
# stablediffusion.cpp (ggml)
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
STABLEDIFFUSION_GGML_VERSION?=9578fdcc4632dc3de5565f28e2fb16b7c18f8d48
ONNX_VERSION?=1.20.0
ONNX_ARCH?=x64
ONNX_OS?=linux
@ -201,6 +209,14 @@ ALL_GRPC_BACKENDS+=backend-assets/grpc/llama-ggml
ALL_GRPC_BACKENDS+=backend-assets/grpc/llama-cpp-grpc
ALL_GRPC_BACKENDS+=backend-assets/util/llama-cpp-rpc-server
ALL_GRPC_BACKENDS+=backend-assets/grpc/whisper
ifeq ($(ONNX_OS),linux)
ifeq ($(ONNX_ARCH),x64)
ALL_GRPC_BACKENDS+=backend-assets/grpc/bark-cpp
ALL_GRPC_BACKENDS+=backend-assets/grpc/stablediffusion-ggml
endif
endif
ALL_GRPC_BACKENDS+=backend-assets/grpc/local-store
ALL_GRPC_BACKENDS+=backend-assets/grpc/silero-vad
ALL_GRPC_BACKENDS+=$(OPTIONAL_GRPC)
@ -236,6 +252,23 @@ sources/go-llama.cpp:
sources/go-llama.cpp/libbinding.a: sources/go-llama.cpp
$(MAKE) -C sources/go-llama.cpp BUILD_TYPE=$(STABLE_BUILD_TYPE) libbinding.a
## bark.cpp
sources/bark.cpp:
git clone --recursive $(BARKCPP_REPO) sources/bark.cpp && \
cd sources/bark.cpp && \
git checkout $(BARKCPP_VERSION) && \
git submodule update --init --recursive --depth 1 --single-branch
sources/bark.cpp/build/libbark.a: sources/bark.cpp
cd sources/bark.cpp && \
mkdir -p build && \
cd build && \
cmake $(CMAKE_ARGS) .. && \
cmake --build . --config Release
backend/go/bark/libbark.a: sources/bark.cpp/build/libbark.a
$(MAKE) -C backend/go/bark libbark.a
## go-piper
sources/go-piper:
mkdir -p sources/go-piper
@ -249,7 +282,7 @@ sources/go-piper:
sources/go-piper/libpiper_binding.a: sources/go-piper
$(MAKE) -C sources/go-piper libpiper_binding.a example/main piper.o
## stable diffusion
## stable diffusion (onnx)
sources/go-stable-diffusion:
mkdir -p sources/go-stable-diffusion
cd sources/go-stable-diffusion && \
@ -262,6 +295,30 @@ sources/go-stable-diffusion:
sources/go-stable-diffusion/libstablediffusion.a: sources/go-stable-diffusion
CPATH="$(CPATH):/usr/include/opencv4" $(MAKE) -C sources/go-stable-diffusion libstablediffusion.a
## stablediffusion (ggml)
sources/stablediffusion-ggml.cpp:
git clone --recursive $(STABLEDIFFUSION_GGML_REPO) sources/stablediffusion-ggml.cpp && \
cd sources/stablediffusion-ggml.cpp && \
git checkout $(STABLEDIFFUSION_GGML_VERSION) && \
git submodule update --init --recursive --depth 1 --single-branch
sources/stablediffusion-ggml.cpp/build/libstable-diffusion.a: sources/stablediffusion-ggml.cpp
cd sources/stablediffusion-ggml.cpp && \
mkdir -p build && \
cd build && \
cmake $(CMAKE_ARGS) .. && \
cmake --build . --config Release
backend/go/image/stablediffusion-ggml/libsd.a: sources/stablediffusion-ggml.cpp/build/libstable-diffusion.a
$(MAKE) -C backend/go/image/stablediffusion-ggml libsd.a
backend-assets/grpc/stablediffusion-ggml: backend/go/image/stablediffusion-ggml/libsd.a backend-assets/grpc
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/backend/go/image/stablediffusion-ggml/ LIBRARY_PATH=$(CURDIR)/backend/go/image/stablediffusion-ggml/ \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/stablediffusion-ggml ./backend/go/image/stablediffusion-ggml/
ifneq ($(UPX),)
$(UPX) backend-assets/grpc/stablediffusion-ggml
endif
sources/onnxruntime:
mkdir -p sources/onnxruntime
curl -L https://github.com/microsoft/onnxruntime/releases/download/v$(ONNX_VERSION)/onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz -o sources/onnxruntime/onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz
@ -302,7 +359,7 @@ sources/whisper.cpp:
sources/whisper.cpp/libwhisper.a: sources/whisper.cpp
cd sources/whisper.cpp && $(MAKE) libwhisper.a libggml.a
get-sources: sources/go-llama.cpp sources/go-piper sources/whisper.cpp sources/go-stable-diffusion sources/go-tiny-dream backend/cpp/llama/llama.cpp
get-sources: sources/go-llama.cpp sources/go-piper sources/stablediffusion-ggml.cpp sources/bark.cpp sources/whisper.cpp sources/go-stable-diffusion sources/go-tiny-dream backend/cpp/llama/llama.cpp
replace:
$(GOCMD) mod edit -replace github.com/ggerganov/whisper.cpp=$(CURDIR)/sources/whisper.cpp
@ -343,7 +400,9 @@ clean: ## Remove build related file
rm -rf release/
rm -rf backend-assets/*
$(MAKE) -C backend/cpp/grpc clean
$(MAKE) -C backend/go/bark clean
$(MAKE) -C backend/cpp/llama clean
$(MAKE) -C backend/go/image/stablediffusion-ggml clean
rm -rf backend/cpp/llama-* || true
$(MAKE) dropreplace
$(MAKE) protogen-clean
@ -792,6 +851,13 @@ ifneq ($(UPX),)
$(UPX) backend-assets/grpc/llama-ggml
endif
backend-assets/grpc/bark-cpp: backend/go/bark/libbark.a backend-assets/grpc
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/backend/go/bark/ LIBRARY_PATH=$(CURDIR)/backend/go/bark/ \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/bark-cpp ./backend/go/bark/
ifneq ($(UPX),)
$(UPX) backend-assets/grpc/bark-cpp
endif
backend-assets/grpc/piper: sources/go-piper sources/go-piper/libpiper_binding.a backend-assets/grpc backend-assets/espeak-ng-data
CGO_CXXFLAGS="$(PIPER_CGO_CXXFLAGS)" CGO_LDFLAGS="$(PIPER_CGO_LDFLAGS)" LIBRARY_PATH=$(CURDIR)/sources/go-piper \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/piper ./backend/go/tts/

View file

@ -92,6 +92,8 @@ local-ai run oci://localai/phi-2:latest
## 📰 Latest project news
- Dec 2024: stablediffusion.cpp backend (ggml) added ( https://github.com/mudler/LocalAI/pull/4289 )
- Nov 2024: Bark.cpp backend added ( https://github.com/mudler/LocalAI/pull/4287 )
- Nov 2024: Voice activity detection models (**VAD**) added to the API: https://github.com/mudler/LocalAI/pull/4204
- Oct 2024: examples moved to [LocalAI-examples](https://github.com/mudler/LocalAI-examples)
- Aug 2024: 🆕 FLUX-1, [P2P Explorer](https://explorer.localai.io)

View file

@ -240,6 +240,11 @@ message ModelOptions {
repeated string LoraAdapters = 60;
repeated float LoraScales = 61;
repeated string Options = 62;
string CacheTypeKey = 63;
string CacheTypeValue = 64;
}
message Result {

View file

@ -681,7 +681,6 @@ struct llama_server_context
slot->sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
slot->sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep);
slot->sparams.seed = json_value(data, "seed", default_sparams.seed);
slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
@ -1213,13 +1212,12 @@ struct llama_server_context
{"mirostat", slot.sparams.mirostat},
{"mirostat_tau", slot.sparams.mirostat_tau},
{"mirostat_eta", slot.sparams.mirostat_eta},
{"penalize_nl", slot.sparams.penalize_nl},
{"stop", slot.params.antiprompt},
{"n_predict", slot.params.n_predict},
{"n_keep", params.n_keep},
{"ignore_eos", slot.sparams.ignore_eos},
{"stream", slot.params.stream},
// {"logit_bias", slot.sparams.logit_bias},
// {"logit_bias", slot.sparams.logit_bias},
{"n_probs", slot.sparams.n_probs},
{"min_keep", slot.sparams.min_keep},
{"grammar", slot.sparams.grammar},
@ -2112,7 +2110,6 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, llama
// slot->sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
// slot->sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
// slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
// slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
// slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep);
// slot->params.seed = json_value(data, "seed", default_params.seed);
// slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
@ -2135,7 +2132,6 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, llama
data["mirostat"] = predict->mirostat();
data["mirostat_tau"] = predict->mirostattau();
data["mirostat_eta"] = predict->mirostateta();
data["penalize_nl"] = predict->penalizenl();
data["n_keep"] = predict->nkeep();
data["seed"] = predict->seed();
data["grammar"] = predict->grammar();
@ -2181,7 +2177,6 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, llama
// llama.params.sparams.mirostat = predict->mirostat();
// llama.params.sparams.mirostat_tau = predict->mirostattau();
// llama.params.sparams.mirostat_eta = predict->mirostateta();
// llama.params.sparams.penalize_nl = predict->penalizenl();
// llama.params.n_keep = predict->nkeep();
// llama.params.seed = predict->seed();
// llama.params.sparams.grammar = predict->grammar();
@ -2228,6 +2223,35 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, llama
// }
// }
const std::vector<ggml_type> kv_cache_types = {
GGML_TYPE_F32,
GGML_TYPE_F16,
GGML_TYPE_BF16,
GGML_TYPE_Q8_0,
GGML_TYPE_Q4_0,
GGML_TYPE_Q4_1,
GGML_TYPE_IQ4_NL,
GGML_TYPE_Q5_0,
GGML_TYPE_Q5_1,
};
static ggml_type kv_cache_type_from_str(const std::string & s) {
for (const auto & type : kv_cache_types) {
if (ggml_type_name(type) == s) {
return type;
}
}
throw std::runtime_error("Unsupported cache type: " + s);
}
static std::string get_all_kv_cache_types() {
std::ostringstream msg;
for (const auto & type : kv_cache_types) {
msg << ggml_type_name(type) << (&type == &kv_cache_types.back() ? "" : ", ");
}
return msg.str();
}
static void params_parse(const backend::ModelOptions* request,
common_params & params) {
@ -2241,6 +2265,12 @@ static void params_parse(const backend::ModelOptions* request,
}
// params.model_alias ??
params.model_alias = request->modelfile();
if (!request->cachetypekey().empty()) {
params.cache_type_k = kv_cache_type_from_str(request->cachetypekey());
}
if (!request->cachetypevalue().empty()) {
params.cache_type_v = kv_cache_type_from_str(request->cachetypevalue());
}
params.n_ctx = request->contextsize();
//params.memory_f16 = request->f16memory();
params.cpuparams.n_threads = request->threads();

25
backend/go/bark/Makefile Normal file
View file

@ -0,0 +1,25 @@
INCLUDE_PATH := $(abspath ./)
LIBRARY_PATH := $(abspath ./)
AR?=ar
BUILD_TYPE?=
# keep standard at C11 and C++11
CXXFLAGS = -I. -I$(INCLUDE_PATH)/../../../sources/bark.cpp/examples -I$(INCLUDE_PATH)/../../../sources/bark.cpp/spm-headers -I$(INCLUDE_PATH)/../../../sources/bark.cpp -O3 -DNDEBUG -std=c++17 -fPIC
LDFLAGS = -L$(LIBRARY_PATH) -L$(LIBRARY_PATH)/../../../sources/bark.cpp/build/examples -lbark -lstdc++ -lm
# warnings
CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function
gobark.o:
$(CXX) $(CXXFLAGS) gobark.cpp -o gobark.o -c $(LDFLAGS)
libbark.a: gobark.o
cp $(INCLUDE_PATH)/../../../sources/bark.cpp/build/libbark.a ./
$(AR) rcs libbark.a gobark.o
$(AR) rcs libbark.a $(LIBRARY_PATH)/../../../sources/bark.cpp/build/encodec.cpp/ggml/src/CMakeFiles/ggml.dir/ggml.c.o
$(AR) rcs libbark.a $(LIBRARY_PATH)/../../../sources/bark.cpp/build/encodec.cpp/ggml/src/CMakeFiles/ggml.dir/ggml-alloc.c.o
$(AR) rcs libbark.a $(LIBRARY_PATH)/../../../sources/bark.cpp/build/encodec.cpp/ggml/src/CMakeFiles/ggml.dir/ggml-backend.c.o
clean:
rm -f gobark.o libbark.a

View file

@ -0,0 +1,85 @@
#include <iostream>
#include <tuple>
#include "bark.h"
#include "gobark.h"
#include "common.h"
#include "ggml.h"
struct bark_context *c;
void bark_print_progress_callback(struct bark_context *bctx, enum bark_encoding_step step, int progress, void *user_data) {
if (step == bark_encoding_step::SEMANTIC) {
printf("\rGenerating semantic tokens... %d%%", progress);
} else if (step == bark_encoding_step::COARSE) {
printf("\rGenerating coarse tokens... %d%%", progress);
} else if (step == bark_encoding_step::FINE) {
printf("\rGenerating fine tokens... %d%%", progress);
}
fflush(stdout);
}
int load_model(char *model) {
// initialize bark context
struct bark_context_params ctx_params = bark_context_default_params();
bark_params params;
params.model_path = model;
// ctx_params.verbosity = verbosity;
ctx_params.progress_callback = bark_print_progress_callback;
ctx_params.progress_callback_user_data = nullptr;
struct bark_context *bctx = bark_load_model(params.model_path.c_str(), ctx_params, params.seed);
if (!bctx) {
fprintf(stderr, "%s: Could not load model\n", __func__);
return 1;
}
c = bctx;
return 0;
}
int tts(char *text,int threads, char *dst ) {
ggml_time_init();
const int64_t t_main_start_us = ggml_time_us();
// generate audio
if (!bark_generate_audio(c, text, threads)) {
fprintf(stderr, "%s: An error occured. If the problem persists, feel free to open an issue to report it.\n", __func__);
return 1;
}
const float *audio_data = bark_get_audio_data(c);
if (audio_data == NULL) {
fprintf(stderr, "%s: Could not get audio data\n", __func__);
return 1;
}
const int audio_arr_size = bark_get_audio_data_size(c);
std::vector<float> audio_arr(audio_data, audio_data + audio_arr_size);
write_wav_on_disk(audio_arr, dst);
// report timing
{
const int64_t t_main_end_us = ggml_time_us();
const int64_t t_load_us = bark_get_load_time(c);
const int64_t t_eval_us = bark_get_eval_time(c);
printf("\n\n");
printf("%s: load time = %8.2f ms\n", __func__, t_load_us / 1000.0f);
printf("%s: eval time = %8.2f ms\n", __func__, t_eval_us / 1000.0f);
printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f);
}
return 0;
}
int unload() {
bark_free(c);
}

52
backend/go/bark/gobark.go Normal file
View file

@ -0,0 +1,52 @@
package main
// #cgo CXXFLAGS: -I${SRCDIR}/../../../sources/bark.cpp/ -I${SRCDIR}/../../../sources/bark.cpp/encodec.cpp -I${SRCDIR}/../../../sources/bark.cpp/examples -I${SRCDIR}/../../../sources/bark.cpp/spm-headers
// #cgo LDFLAGS: -L${SRCDIR}/ -L${SRCDIR}/../../../sources/bark.cpp/build/examples -L${SRCDIR}/../../../sources/bark.cpp/build/encodec.cpp/ -lbark -lencodec -lcommon
// #include <gobark.h>
// #include <stdlib.h>
import "C"
import (
"fmt"
"unsafe"
"github.com/mudler/LocalAI/pkg/grpc/base"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)
type Bark struct {
base.SingleThread
threads int
}
func (sd *Bark) Load(opts *pb.ModelOptions) error {
sd.threads = int(opts.Threads)
modelFile := C.CString(opts.ModelFile)
defer C.free(unsafe.Pointer(modelFile))
ret := C.load_model(modelFile)
if ret != 0 {
return fmt.Errorf("inference failed")
}
return nil
}
func (sd *Bark) TTS(opts *pb.TTSRequest) error {
t := C.CString(opts.Text)
defer C.free(unsafe.Pointer(t))
dst := C.CString(opts.Dst)
defer C.free(unsafe.Pointer(dst))
threads := C.int(sd.threads)
ret := C.tts(t, threads, dst)
if ret != 0 {
return fmt.Errorf("inference failed")
}
return nil
}

8
backend/go/bark/gobark.h Normal file
View file

@ -0,0 +1,8 @@
#ifdef __cplusplus
extern "C" {
#endif
int load_model(char *model);
int tts(char *text,int threads, char *dst );
#ifdef __cplusplus
}
#endif

20
backend/go/bark/main.go Normal file
View file

@ -0,0 +1,20 @@
package main
// Note: this is started internally by LocalAI and a server is allocated for each model
import (
"flag"
grpc "github.com/mudler/LocalAI/pkg/grpc"
)
var (
addr = flag.String("addr", "localhost:50051", "the address to connect to")
)
func main() {
flag.Parse()
if err := grpc.StartServer(*addr, &Bark{}); err != nil {
panic(err)
}
}

View file

@ -0,0 +1,21 @@
INCLUDE_PATH := $(abspath ./)
LIBRARY_PATH := $(abspath ./)
AR?=ar
BUILD_TYPE?=
# keep standard at C11 and C++11
CXXFLAGS = -I. -I$(INCLUDE_PATH)/../../../../sources/stablediffusion-ggml.cpp/thirdparty -I$(INCLUDE_PATH)/../../../../sources/stablediffusion-ggml.cpp/ggml/include -I$(INCLUDE_PATH)/../../../../sources/stablediffusion-ggml.cpp -O3 -DNDEBUG -std=c++17 -fPIC
# warnings
CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function
gosd.o:
$(CXX) $(CXXFLAGS) gosd.cpp -o gosd.o -c
libsd.a: gosd.o
cp $(INCLUDE_PATH)/../../../../sources/stablediffusion-ggml.cpp/build/libstable-diffusion.a ./libsd.a
$(AR) rcs libsd.a gosd.o
clean:
rm -f gosd.o libsd.a

View file

@ -0,0 +1,228 @@
#include <stdio.h>
#include <string.h>
#include <time.h>
#include <iostream>
#include <random>
#include <string>
#include <vector>
#include "gosd.h"
// #include "preprocessing.hpp"
#include "flux.hpp"
#include "stable-diffusion.h"
#define STB_IMAGE_IMPLEMENTATION
#define STB_IMAGE_STATIC
#include "stb_image.h"
#define STB_IMAGE_WRITE_IMPLEMENTATION
#define STB_IMAGE_WRITE_STATIC
#include "stb_image_write.h"
#define STB_IMAGE_RESIZE_IMPLEMENTATION
#define STB_IMAGE_RESIZE_STATIC
#include "stb_image_resize.h"
// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
const char* sample_method_str[] = {
"euler_a",
"euler",
"heun",
"dpm2",
"dpm++2s_a",
"dpm++2m",
"dpm++2mv2",
"ipndm",
"ipndm_v",
"lcm",
};
// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
const char* schedule_str[] = {
"default",
"discrete",
"karras",
"exponential",
"ays",
"gits",
};
sd_ctx_t* sd_c;
sample_method_t sample_method;
int load_model(char *model, char* options[], int threads, int diff) {
fprintf (stderr, "Loading model!\n");
char *stableDiffusionModel = "";
if (diff == 1 ) {
stableDiffusionModel = model;
model = "";
}
// decode options. Options are in form optname:optvale, or if booleans only optname.
char *clip_l_path = "";
char *clip_g_path = "";
char *t5xxl_path = "";
char *vae_path = "";
char *scheduler = "";
char *sampler = "";
// If options is not NULL, parse options
for (int i = 0; options[i] != NULL; i++) {
char *optname = strtok(options[i], ":");
char *optval = strtok(NULL, ":");
if (optval == NULL) {
optval = "true";
}
if (!strcmp(optname, "clip_l_path")) {
clip_l_path = optval;
}
if (!strcmp(optname, "clip_g_path")) {
clip_g_path = optval;
}
if (!strcmp(optname, "t5xxl_path")) {
t5xxl_path = optval;
}
if (!strcmp(optname, "vae_path")) {
vae_path = optval;
}
if (!strcmp(optname, "scheduler")) {
scheduler = optval;
}
if (!strcmp(optname, "sampler")) {
sampler = optval;
}
}
int sample_method_found = -1;
for (int m = 0; m < N_SAMPLE_METHODS; m++) {
if (!strcmp(sampler, sample_method_str[m])) {
sample_method_found = m;
}
}
if (sample_method_found == -1) {
fprintf(stderr, "Invalid sample method, default to EULER_A!\n");
sample_method_found = EULER_A;
}
sample_method = (sample_method_t)sample_method_found;
int schedule_found = -1;
for (int d = 0; d < N_SCHEDULES; d++) {
if (!strcmp(scheduler, schedule_str[d])) {
schedule_found = d;
fprintf (stderr, "Found scheduler: %s\n", scheduler);
}
}
if (schedule_found == -1) {
fprintf (stderr, "Invalid scheduler! using DEFAULT\n");
schedule_found = DEFAULT;
}
schedule_t schedule = (schedule_t)schedule_found;
fprintf (stderr, "Creating context\n");
sd_ctx_t* sd_ctx = new_sd_ctx(model,
clip_l_path,
clip_g_path,
t5xxl_path,
stableDiffusionModel,
vae_path,
"",
"",
"",
"",
"",
false,
false,
false,
threads,
SD_TYPE_COUNT,
STD_DEFAULT_RNG,
schedule,
false,
false,
false,
false);
if (sd_ctx == NULL) {
fprintf (stderr, "failed loading model (generic error)\n");
return 1;
}
fprintf (stderr, "Created context: OK\n");
sd_c = sd_ctx;
return 0;
}
int gen_image(char *text, char *negativeText, int width, int height, int steps, int seed , char *dst, float cfg_scale) {
sd_image_t* results;
std::vector<int> skip_layers = {7, 8, 9};
fprintf (stderr, "Generating image\n");
results = txt2img(sd_c,
text,
negativeText,
-1, //clip_skip
cfg_scale, // sfg_scale
3.5f,
width,
height,
sample_method,
steps,
seed,
1,
NULL,
0.9f,
20.f,
false,
"",
skip_layers.data(),
skip_layers.size(),
0,
0.01,
0.2);
if (results == NULL) {
fprintf (stderr, "NO results\n");
return 1;
}
if (results[0].data == NULL) {
fprintf (stderr, "Results with no data\n");
return 1;
}
fprintf (stderr, "Writing PNG\n");
fprintf (stderr, "DST: %s\n", dst);
fprintf (stderr, "Width: %d\n", results[0].width);
fprintf (stderr, "Height: %d\n", results[0].height);
fprintf (stderr, "Channel: %d\n", results[0].channel);
fprintf (stderr, "Data: %p\n", results[0].data);
stbi_write_png(dst, results[0].width, results[0].height, results[0].channel,
results[0].data, 0, NULL);
fprintf (stderr, "Saved resulting image to '%s'\n", dst);
// TODO: free results. Why does it crash?
free(results[0].data);
results[0].data = NULL;
free(results);
fprintf (stderr, "gen_image is done", dst);
return 0;
}
int unload() {
free_sd_ctx(sd_c);
}

View file

@ -0,0 +1,96 @@
package main
// #cgo CXXFLAGS: -I${SRCDIR}/../../../../sources/stablediffusion-ggml.cpp/thirdparty -I${SRCDIR}/../../../../sources/stablediffusion-ggml.cpp -I${SRCDIR}/../../../../sources/stablediffusion-ggml.cpp/ggml/include
// #cgo LDFLAGS: -L${SRCDIR}/ -L${SRCDIR}/../../../../sources/stablediffusion-ggml.cpp/build/ggml/src/ggml-cpu -L${SRCDIR}/../../../../sources/stablediffusion-ggml.cpp/build/ggml/src -lsd -lstdc++ -lm -lggml -lggml-base -lggml-cpu -lgomp
// #include <gosd.h>
// #include <stdlib.h>
import "C"
import (
"fmt"
"os"
"path/filepath"
"strings"
"unsafe"
"github.com/mudler/LocalAI/pkg/grpc/base"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/utils"
)
type SDGGML struct {
base.SingleThread
threads int
sampleMethod string
cfgScale float32
}
func (sd *SDGGML) Load(opts *pb.ModelOptions) error {
sd.threads = int(opts.Threads)
modelFile := C.CString(opts.ModelFile)
defer C.free(unsafe.Pointer(modelFile))
var options **C.char
// prepare the options array to pass to C
size := C.size_t(unsafe.Sizeof((*C.char)(nil)))
length := C.size_t(len(opts.Options))
options = (**C.char)(C.malloc(length * size))
view := (*[1 << 30]*C.char)(unsafe.Pointer(options))[0:len(opts.Options):len(opts.Options)]
var diffusionModel int
var oo []string
for _, op := range opts.Options {
if op == "diffusion_model" {
diffusionModel = 1
continue
}
// If it's an option path, we resolve absolute path from the model path
if strings.Contains(op, ":") && strings.Contains(op, "path") {
data := strings.Split(op, ":")
data[1] = filepath.Join(opts.ModelPath, data[1])
if err := utils.VerifyPath(data[1], opts.ModelPath); err == nil {
oo = append(oo, strings.Join(data, ":"))
}
} else {
oo = append(oo, op)
}
}
fmt.Fprintf(os.Stderr, "Options: %+v\n", oo)
for i, x := range oo {
view[i] = C.CString(x)
}
sd.cfgScale = opts.CFGScale
ret := C.load_model(modelFile, options, C.int(opts.Threads), C.int(diffusionModel))
if ret != 0 {
return fmt.Errorf("could not load model")
}
return nil
}
func (sd *SDGGML) GenerateImage(opts *pb.GenerateImageRequest) error {
t := C.CString(opts.PositivePrompt)
defer C.free(unsafe.Pointer(t))
dst := C.CString(opts.Dst)
defer C.free(unsafe.Pointer(dst))
negative := C.CString(opts.NegativePrompt)
defer C.free(unsafe.Pointer(negative))
ret := C.gen_image(t, negative, C.int(opts.Width), C.int(opts.Height), C.int(opts.Step), C.int(opts.Seed), dst, C.float(sd.cfgScale))
if ret != 0 {
return fmt.Errorf("inference failed")
}
return nil
}

View file

@ -0,0 +1,8 @@
#ifdef __cplusplus
extern "C" {
#endif
int load_model(char *model, char* options[], int threads, int diffusionModel);
int gen_image(char *text, char *negativeText, int width, int height, int steps, int seed, char *dst, float cfg_scale);
#ifdef __cplusplus
}
#endif

View file

@ -0,0 +1,20 @@
package main
// Note: this is started internally by LocalAI and a server is allocated for each model
import (
"flag"
grpc "github.com/mudler/LocalAI/pkg/grpc"
)
var (
addr = flag.String("addr", "localhost:50051", "the address to connect to")
)
func main() {
flag.Parse()
if err := grpc.StartServer(*addr, &SDGGML{}); err != nil {
panic(err)
}
}

View file

@ -2,4 +2,4 @@
intel-extension-for-pytorch
torch
optimum[openvino]
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
setuptools

View file

@ -1,6 +1,6 @@
accelerate
auto-gptq==0.7.1
grpcio==1.68.0
grpcio==1.68.1
protobuf
certifi
transformers

View file

@ -3,6 +3,6 @@ intel-extension-for-pytorch
torch
torchaudio
optimum[openvino]
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
setuptools
transformers
accelerate

View file

@ -1,4 +1,4 @@
bark==0.1.5
grpcio==1.68.0
grpcio==1.68.1
protobuf
certifi

View file

@ -17,6 +17,9 @@
# LIMIT_TARGETS="cublas12"
# source $(dirname $0)/../common/libbackend.sh
#
PYTHON_VERSION="3.10"
function init() {
# Name of the backend (directory name)
BACKEND_NAME=${PWD##*/}
@ -88,7 +91,7 @@ function getBuildProfile() {
# always result in an activated virtual environment
function ensureVenv() {
if [ ! -d "${EDIR}/venv" ]; then
uv venv ${EDIR}/venv
uv venv --python ${PYTHON_VERSION} ${EDIR}/venv
echo "virtualenv created"
fi

View file

@ -1,3 +1,3 @@
grpcio==1.68.0
grpcio==1.68.1
protobuf
grpcio-tools

View file

@ -3,7 +3,7 @@ intel-extension-for-pytorch
torch
torchaudio
optimum[openvino]
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
setuptools
transformers
accelerate
coqui-tts

View file

@ -1,4 +1,4 @@
grpcio==1.68.0
grpcio==1.68.1
protobuf
certifi
packaging==24.1

View file

@ -3,7 +3,7 @@ intel-extension-for-pytorch
torch
torchvision
optimum[openvino]
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
setuptools
diffusers
opencv-python
transformers

View file

@ -1,5 +1,5 @@
setuptools
grpcio==1.68.0
grpcio==1.68.1
pillow
protobuf
certifi

View file

@ -1,4 +1,4 @@
grpcio==1.68.0
grpcio==1.68.1
protobuf
certifi
wheel

View file

@ -1,3 +1,3 @@
grpcio==1.68.0
grpcio==1.68.1
protobuf
certifi

View file

@ -2,7 +2,7 @@
intel-extension-for-pytorch
torch
optimum[openvino]
grpcio==1.68.0
grpcio==1.68.1
protobuf
librosa==0.9.1
faster-whisper==0.9.0

View file

@ -1,4 +1,4 @@
grpcio==1.68.0
grpcio==1.68.1
protobuf
librosa
faster-whisper
@ -18,3 +18,4 @@ jieba==0.42.1
gradio==3.48.0
langid==1.1.6
llvmlite==0.43.0
setuptools

View file

@ -3,6 +3,5 @@ intel-extension-for-pytorch
torch
torchaudio
optimum[openvino]
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
transformers
accelerate

View file

@ -1,3 +1,4 @@
grpcio==1.68.0
grpcio==1.68.1
certifi
llvmlite==0.43.0
setuptools

View file

@ -5,4 +5,4 @@ accelerate
torch
rerankers[transformers]
optimum[openvino]
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
setuptools

View file

@ -1,3 +1,3 @@
grpcio==1.68.0
grpcio==1.68.1
protobuf
certifi

View file

@ -2,7 +2,7 @@
intel-extension-for-pytorch
torch
optimum[openvino]
setuptools==69.5.1 # https://github.com/mudler/LocalAI/issues/2406
setuptools
accelerate
sentence-transformers==3.3.1
transformers

View file

@ -1,4 +1,4 @@
grpcio==1.68.0
grpcio==1.68.1
protobuf
certifi
datasets

View file

@ -4,4 +4,4 @@ transformers
accelerate
torch
optimum[openvino]
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
setuptools

View file

@ -1,4 +1,4 @@
grpcio==1.68.0
grpcio==1.68.1
protobuf
scipy==1.14.0
certifi

View file

@ -1,4 +1,4 @@
grpcio==1.68.0
grpcio==1.68.1
protobuf
certifi
setuptools==69.5.1 # https://github.com/mudler/LocalAI/issues/2406
setuptools

View file

@ -3,5 +3,4 @@ intel-extension-for-pytorch
accelerate
torch
torchaudio
optimum[openvino]
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
optimum[openvino]

View file

@ -1,3 +1,4 @@
grpcio==1.68.0
grpcio==1.68.1
protobuf
certifi
certifi
setuptools

View file

@ -22,7 +22,7 @@ if [ "x${BUILD_TYPE}" == "x" ] && [ "x${FROM_SOURCE}" == "xtrue" ]; then
git clone https://github.com/vllm-project/vllm
fi
pushd vllm
uv pip install wheel packaging ninja "setuptools>=49.4.0" numpy typing-extensions pillow setuptools-scm grpcio==1.68.0 protobuf bitsandbytes
uv pip install wheel packaging ninja "setuptools>=49.4.0" numpy typing-extensions pillow setuptools-scm grpcio==1.68.1 protobuf bitsandbytes
uv pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
VLLM_TARGET_DEVICE=cpu python setup.py install
popd

View file

@ -4,5 +4,5 @@ accelerate
torch
transformers
optimum[openvino]
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
setuptools
bitsandbytes

View file

@ -1,4 +1,4 @@
grpcio==1.68.0
grpcio==1.68.1
protobuf
certifi
setuptools

View file

@ -1,38 +0,0 @@
package core
import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/model"
)
// The purpose of this structure is to hold pointers to all initialized services, to make plumbing easy
// Perhaps a proper DI system is worth it in the future, but for now keep things simple.
type Application struct {
// Application-Level Config
ApplicationConfig *config.ApplicationConfig
// ApplicationState *ApplicationState
// Core Low-Level Services
BackendConfigLoader *config.BackendConfigLoader
ModelLoader *model.ModelLoader
// Backend Services
// EmbeddingsBackendService *backend.EmbeddingsBackendService
// ImageGenerationBackendService *backend.ImageGenerationBackendService
// LLMBackendService *backend.LLMBackendService
// TranscriptionBackendService *backend.TranscriptionBackendService
// TextToSpeechBackendService *backend.TextToSpeechBackendService
// LocalAI System Services
BackendMonitorService *services.BackendMonitorService
GalleryService *services.GalleryService
LocalAIMetricsService *services.LocalAIMetricsService
// OpenAIService *services.OpenAIService
}
// TODO [NEXT PR?]: Break up ApplicationConfig.
// Migrate over stuff that is not set via config at all - especially runtime stuff
type ApplicationState struct {
}

View file

@ -0,0 +1,39 @@
package application
import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/templates"
)
type Application struct {
backendLoader *config.BackendConfigLoader
modelLoader *model.ModelLoader
applicationConfig *config.ApplicationConfig
templatesEvaluator *templates.Evaluator
}
func newApplication(appConfig *config.ApplicationConfig) *Application {
return &Application{
backendLoader: config.NewBackendConfigLoader(appConfig.ModelPath),
modelLoader: model.NewModelLoader(appConfig.ModelPath),
applicationConfig: appConfig,
templatesEvaluator: templates.NewEvaluator(appConfig.ModelPath),
}
}
func (a *Application) BackendLoader() *config.BackendConfigLoader {
return a.backendLoader
}
func (a *Application) ModelLoader() *model.ModelLoader {
return a.modelLoader
}
func (a *Application) ApplicationConfig() *config.ApplicationConfig {
return a.applicationConfig
}
func (a *Application) TemplatesEvaluator() *templates.Evaluator {
return a.templatesEvaluator
}

View file

@ -1,4 +1,4 @@
package startup
package application
import (
"encoding/json"
@ -8,8 +8,8 @@ import (
"path/filepath"
"time"
"github.com/fsnotify/fsnotify"
"dario.cat/mergo"
"github.com/fsnotify/fsnotify"
"github.com/mudler/LocalAI/core/config"
"github.com/rs/zerolog/log"
)

View file

@ -1,15 +1,15 @@
package startup
package application
import (
"fmt"
"os"
"github.com/mudler/LocalAI/core"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/assets"
"github.com/mudler/LocalAI/pkg/library"
"github.com/mudler/LocalAI/pkg/model"
pkgStartup "github.com/mudler/LocalAI/pkg/startup"
@ -17,8 +17,9 @@ import (
"github.com/rs/zerolog/log"
)
func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.ModelLoader, *config.ApplicationConfig, error) {
func New(opts ...config.AppOption) (*Application, error) {
options := config.NewApplicationConfig(opts...)
application := newApplication(options)
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.ModelPath)
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
@ -36,28 +37,28 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
// Make sure directories exists
if options.ModelPath == "" {
return nil, nil, nil, fmt.Errorf("options.ModelPath cannot be empty")
return nil, fmt.Errorf("options.ModelPath cannot be empty")
}
err = os.MkdirAll(options.ModelPath, 0750)
if err != nil {
return nil, nil, nil, fmt.Errorf("unable to create ModelPath: %q", err)
return nil, fmt.Errorf("unable to create ModelPath: %q", err)
}
if options.ImageDir != "" {
err := os.MkdirAll(options.ImageDir, 0750)
if err != nil {
return nil, nil, nil, fmt.Errorf("unable to create ImageDir: %q", err)
return nil, fmt.Errorf("unable to create ImageDir: %q", err)
}
}
if options.AudioDir != "" {
err := os.MkdirAll(options.AudioDir, 0750)
if err != nil {
return nil, nil, nil, fmt.Errorf("unable to create AudioDir: %q", err)
return nil, fmt.Errorf("unable to create AudioDir: %q", err)
}
}
if options.UploadDir != "" {
err := os.MkdirAll(options.UploadDir, 0750)
if err != nil {
return nil, nil, nil, fmt.Errorf("unable to create UploadDir: %q", err)
return nil, fmt.Errorf("unable to create UploadDir: %q", err)
}
}
@ -65,39 +66,36 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
log.Error().Err(err).Msg("error installing models")
}
cl := config.NewBackendConfigLoader(options.ModelPath)
ml := model.NewModelLoader(options.ModelPath)
configLoaderOpts := options.ToConfigLoaderOptions()
if err := cl.LoadBackendConfigsFromPath(options.ModelPath, configLoaderOpts...); err != nil {
if err := application.BackendLoader().LoadBackendConfigsFromPath(options.ModelPath, configLoaderOpts...); err != nil {
log.Error().Err(err).Msg("error loading config files")
}
if options.ConfigFile != "" {
if err := cl.LoadMultipleBackendConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil {
if err := application.BackendLoader().LoadMultipleBackendConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil {
log.Error().Err(err).Msg("error loading config file")
}
}
if err := cl.Preload(options.ModelPath); err != nil {
if err := application.BackendLoader().Preload(options.ModelPath); err != nil {
log.Error().Err(err).Msg("error downloading models")
}
if options.PreloadJSONModels != "" {
if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, options.EnforcePredownloadScans, options.Galleries); err != nil {
return nil, nil, nil, err
return nil, err
}
}
if options.PreloadModelsFromPath != "" {
if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, options.EnforcePredownloadScans, options.Galleries); err != nil {
return nil, nil, nil, err
return nil, err
}
}
if options.Debug {
for _, v := range cl.GetAllBackendConfigs() {
for _, v := range application.BackendLoader().GetAllBackendConfigs() {
log.Debug().Msgf("Model: %s (config: %+v)", v.Name, v)
}
}
@ -123,7 +121,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
go func() {
<-options.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
err := ml.StopAllGRPC()
err := application.ModelLoader().StopAllGRPC()
if err != nil {
log.Error().Err(err).Msg("error while stopping all grpc backends")
}
@ -131,12 +129,12 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
if options.WatchDog {
wd := model.NewWatchDog(
ml,
application.ModelLoader(),
options.WatchDogBusyTimeout,
options.WatchDogIdleTimeout,
options.WatchDogBusy,
options.WatchDogIdle)
ml.SetWatchDog(wd)
application.ModelLoader().SetWatchDog(wd)
go wd.Run()
go func() {
<-options.Context.Done()
@ -147,7 +145,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
if options.LoadToMemory != nil {
for _, m := range options.LoadToMemory {
cfg, err := cl.LoadBackendConfigFileByName(m, options.ModelPath,
cfg, err := application.BackendLoader().LoadBackendConfigFileByName(m, options.ModelPath,
config.LoadOptionDebug(options.Debug),
config.LoadOptionThreads(options.Threads),
config.LoadOptionContextSize(options.ContextSize),
@ -155,7 +153,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
config.ModelPath(options.ModelPath),
)
if err != nil {
return nil, nil, nil, err
return nil, err
}
log.Debug().Msgf("Auto loading model %s into memory from file: %s", m, cfg.Model)
@ -163,9 +161,9 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
o := backend.ModelOptions(*cfg, options)
var backendErr error
_, backendErr = ml.Load(o...)
_, backendErr = application.ModelLoader().Load(o...)
if backendErr != nil {
return nil, nil, nil, err
return nil, err
}
}
}
@ -174,7 +172,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
startWatcher(options)
log.Info().Msg("core/startup process completed!")
return cl, ml, options, nil
return application, nil
}
func startWatcher(options *config.ApplicationConfig) {
@ -201,32 +199,3 @@ func startWatcher(options *config.ApplicationConfig) {
log.Error().Err(err).Msg("failed creating watcher")
}
}
// In Lieu of a proper DI framework, this function wires up the Application manually.
// This is in core/startup rather than core/state.go to keep package references clean!
func createApplication(appConfig *config.ApplicationConfig) *core.Application {
app := &core.Application{
ApplicationConfig: appConfig,
BackendConfigLoader: config.NewBackendConfigLoader(appConfig.ModelPath),
ModelLoader: model.NewModelLoader(appConfig.ModelPath),
}
var err error
// app.EmbeddingsBackendService = backend.NewEmbeddingsBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
// app.ImageGenerationBackendService = backend.NewImageGenerationBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
// app.LLMBackendService = backend.NewLLMBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
// app.TranscriptionBackendService = backend.NewTranscriptionBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
// app.TextToSpeechBackendService = backend.NewTextToSpeechBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
app.BackendMonitorService = services.NewBackendMonitorService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
app.GalleryService = services.NewGalleryService(app.ApplicationConfig)
// app.OpenAIService = services.NewOpenAIService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig, app.LLMBackendService)
app.LocalAIMetricsService, err = services.NewLocalAIMetricsService()
if err != nil {
log.Error().Err(err).Msg("encountered an error initializing metrics service, startup will continue but metrics will not be tracked.")
}
return app
}

View file

@ -122,7 +122,7 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
CUDA: c.CUDA || c.Diffusers.CUDA,
SchedulerType: c.Diffusers.SchedulerType,
PipelineType: c.Diffusers.PipelineType,
CFGScale: c.Diffusers.CFGScale,
CFGScale: c.CFGScale,
LoraAdapter: c.LoraAdapter,
LoraScale: c.LoraScale,
LoraAdapters: c.LoraAdapters,
@ -132,6 +132,7 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
IMG2IMG: c.Diffusers.IMG2IMG,
CLIPModel: c.Diffusers.ClipModel,
CLIPSubfolder: c.Diffusers.ClipSubFolder,
Options: c.Options,
CLIPSkip: int32(c.Diffusers.ClipSkip),
ControlNet: c.Diffusers.ControlNet,
ContextSize: int32(ctxSize),
@ -150,6 +151,8 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
TensorParallelSize: int32(c.TensorParallelSize),
MMProj: c.MMProj,
FlashAttention: c.FlashAttention,
CacheTypeKey: c.CacheTypeK,
CacheTypeValue: c.CacheTypeV,
NoKVOffload: c.NoKVOffloading,
YarnExtFactor: c.YarnExtFactor,
YarnAttnFactor: c.YarnAttnFactor,

View file

@ -6,12 +6,12 @@ import (
"strings"
"time"
"github.com/mudler/LocalAI/core/application"
cli_api "github.com/mudler/LocalAI/core/cli/api"
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http"
"github.com/mudler/LocalAI/core/p2p"
"github.com/mudler/LocalAI/core/startup"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
@ -186,16 +186,16 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
}
if r.PreloadBackendOnly {
_, _, _, err := startup.Startup(opts...)
_, err := application.New(opts...)
return err
}
cl, ml, options, err := startup.Startup(opts...)
app, err := application.New(opts...)
if err != nil {
return fmt.Errorf("failed basic startup tasks with error %s", err.Error())
}
appHTTP, err := http.App(cl, ml, options)
appHTTP, err := http.API(app)
if err != nil {
log.Error().Err(err).Msg("error during HTTP App construction")
return err

View file

@ -72,6 +72,8 @@ type BackendConfig struct {
Description string `yaml:"description"`
Usage string `yaml:"usage"`
Options []string `yaml:"options"`
}
type File struct {
@ -97,16 +99,15 @@ type GRPC struct {
}
type Diffusers struct {
CUDA bool `yaml:"cuda"`
PipelineType string `yaml:"pipeline_type"`
SchedulerType string `yaml:"scheduler_type"`
EnableParameters string `yaml:"enable_parameters"` // A list of comma separated parameters to specify
CFGScale float32 `yaml:"cfg_scale"` // Classifier-Free Guidance Scale
IMG2IMG bool `yaml:"img2img"` // Image to Image Diffuser
ClipSkip int `yaml:"clip_skip"` // Skip every N frames
ClipModel string `yaml:"clip_model"` // Clip model to use
ClipSubFolder string `yaml:"clip_subfolder"` // Subfolder to use for clip model
ControlNet string `yaml:"control_net"`
CUDA bool `yaml:"cuda"`
PipelineType string `yaml:"pipeline_type"`
SchedulerType string `yaml:"scheduler_type"`
EnableParameters string `yaml:"enable_parameters"` // A list of comma separated parameters to specify
IMG2IMG bool `yaml:"img2img"` // Image to Image Diffuser
ClipSkip int `yaml:"clip_skip"` // Skip every N frames
ClipModel string `yaml:"clip_model"` // Clip model to use
ClipSubFolder string `yaml:"clip_subfolder"` // Subfolder to use for clip model
ControlNet string `yaml:"control_net"`
}
// LLMConfig is a struct that holds the configuration that are
@ -154,8 +155,10 @@ type LLMConfig struct {
TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM
MMProj string `yaml:"mmproj"`
FlashAttention bool `yaml:"flash_attention"`
NoKVOffloading bool `yaml:"no_kv_offloading"`
FlashAttention bool `yaml:"flash_attention"`
NoKVOffloading bool `yaml:"no_kv_offloading"`
CacheTypeK string `yaml:"cache_type_k"`
CacheTypeV string `yaml:"cache_type_v"`
RopeScaling string `yaml:"rope_scaling"`
ModelType string `yaml:"type"`
@ -164,6 +167,8 @@ type LLMConfig struct {
YarnAttnFactor float32 `yaml:"yarn_attn_factor"`
YarnBetaFast float32 `yaml:"yarn_beta_fast"`
YarnBetaSlow float32 `yaml:"yarn_beta_slow"`
CFGScale float32 `yaml:"cfg_scale"` // Classifier-Free Guidance Scale
}
// AutoGPTQ is a struct that holds the configuration specific to the AutoGPTQ backend
@ -201,6 +206,8 @@ type TemplateConfig struct {
JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character"`
Multimodal string `yaml:"multimodal"`
JinjaTemplate bool `yaml:"jinja_template"`
}
func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {

View file

@ -26,14 +26,14 @@ const (
type settingsConfig struct {
StopWords []string
TemplateConfig TemplateConfig
RepeatPenalty float64
RepeatPenalty float64
}
// default settings to adopt with a given model family
var defaultsSettings map[familyType]settingsConfig = map[familyType]settingsConfig{
Gemma: {
RepeatPenalty: 1.0,
StopWords: []string{"<|im_end|>", "<end_of_turn>", "<start_of_turn>"},
StopWords: []string{"<|im_end|>", "<end_of_turn>", "<start_of_turn>"},
TemplateConfig: TemplateConfig{
Chat: "{{.Input }}\n<start_of_turn>model\n",
ChatMessage: "<start_of_turn>{{if eq .RoleName \"assistant\" }}model{{else}}{{ .RoleName }}{{end}}\n{{ if .Content -}}\n{{.Content -}}\n{{ end -}}<end_of_turn>",
@ -200,6 +200,18 @@ func guessDefaultsFromFile(cfg *BackendConfig, modelPath string) {
} else {
log.Debug().Any("family", family).Msgf("guessDefaultsFromFile: no template found for family")
}
if cfg.HasTemplate() {
return
}
// identify from well known templates first, otherwise use the raw jinja template
chatTemplate, found := f.Header.MetadataKV.Get("tokenizer.chat_template")
if found {
// try to use the jinja template
cfg.TemplateConfig.JinjaTemplate = true
cfg.TemplateConfig.ChatMessage = chatTemplate.ValueString()
}
}
func identifyFamily(f *gguf.GGUFFile) familyType {

View file

@ -14,10 +14,9 @@ import (
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/http/routes"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/contrib/fiberzerolog"
"github.com/gofiber/fiber/v2"
@ -49,18 +48,18 @@ var embedDirStatic embed.FS
// @in header
// @name Authorization
func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (*fiber.App, error) {
func API(application *application.Application) (*fiber.App, error) {
fiberCfg := fiber.Config{
Views: renderEngine(),
BodyLimit: appConfig.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
BodyLimit: application.ApplicationConfig().UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
// We disable the Fiber startup message as it does not conform to structured logging.
// We register a startup log line with connection information in the OnListen hook to keep things user friendly though
DisableStartupMessage: true,
// Override default error handler
}
if !appConfig.OpaqueErrors {
if !application.ApplicationConfig().OpaqueErrors {
// Normally, return errors as JSON responses
fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, err error) error {
// Status code defaults to 500
@ -86,9 +85,9 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
}
}
app := fiber.New(fiberCfg)
router := fiber.New(fiberCfg)
app.Hooks().OnListen(func(listenData fiber.ListenData) error {
router.Hooks().OnListen(func(listenData fiber.ListenData) error {
scheme := "http"
if listenData.TLS {
scheme = "https"
@ -99,82 +98,82 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
// Have Fiber use zerolog like the rest of the application rather than it's built-in logger
logger := log.Logger
app.Use(fiberzerolog.New(fiberzerolog.Config{
router.Use(fiberzerolog.New(fiberzerolog.Config{
Logger: &logger,
}))
// Default middleware config
if !appConfig.Debug {
app.Use(recover.New())
if !application.ApplicationConfig().Debug {
router.Use(recover.New())
}
if !appConfig.DisableMetrics {
if !application.ApplicationConfig().DisableMetrics {
metricsService, err := services.NewLocalAIMetricsService()
if err != nil {
return nil, err
}
if metricsService != nil {
app.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
app.Hooks().OnShutdown(func() error {
router.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
router.Hooks().OnShutdown(func() error {
return metricsService.Shutdown()
})
}
}
// Health Checks should always be exempt from auth, so register these first
routes.HealthRoutes(app)
routes.HealthRoutes(router)
kaConfig, err := middleware.GetKeyAuthConfig(appConfig)
kaConfig, err := middleware.GetKeyAuthConfig(application.ApplicationConfig())
if err != nil || kaConfig == nil {
return nil, fmt.Errorf("failed to create key auth config: %w", err)
}
// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration
app.Use(v2keyauth.New(*kaConfig))
router.Use(v2keyauth.New(*kaConfig))
if appConfig.CORS {
if application.ApplicationConfig().CORS {
var c func(ctx *fiber.Ctx) error
if appConfig.CORSAllowOrigins == "" {
if application.ApplicationConfig().CORSAllowOrigins == "" {
c = cors.New()
} else {
c = cors.New(cors.Config{AllowOrigins: appConfig.CORSAllowOrigins})
c = cors.New(cors.Config{AllowOrigins: application.ApplicationConfig().CORSAllowOrigins})
}
app.Use(c)
router.Use(c)
}
if appConfig.CSRF {
if application.ApplicationConfig().CSRF {
log.Debug().Msg("Enabling CSRF middleware. Tokens are now required for state-modifying requests")
app.Use(csrf.New())
router.Use(csrf.New())
}
// Load config jsons
utils.LoadConfig(appConfig.UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles)
utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants)
utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles)
utils.LoadConfig(application.ApplicationConfig().UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles)
utils.LoadConfig(application.ApplicationConfig().ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants)
utils.LoadConfig(application.ApplicationConfig().ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles)
galleryService := services.NewGalleryService(appConfig)
galleryService.Start(appConfig.Context, cl)
galleryService := services.NewGalleryService(application.ApplicationConfig())
galleryService.Start(application.ApplicationConfig().Context, application.BackendLoader())
routes.RegisterElevenLabsRoutes(app, cl, ml, appConfig)
routes.RegisterLocalAIRoutes(app, cl, ml, appConfig, galleryService)
routes.RegisterOpenAIRoutes(app, cl, ml, appConfig)
if !appConfig.DisableWebUI {
routes.RegisterUIRoutes(app, cl, ml, appConfig, galleryService)
routes.RegisterElevenLabsRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
routes.RegisterLocalAIRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService)
routes.RegisterOpenAIRoutes(router, application)
if !application.ApplicationConfig().DisableWebUI {
routes.RegisterUIRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService)
}
routes.RegisterJINARoutes(app, cl, ml, appConfig)
routes.RegisterJINARoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
httpFS := http.FS(embedDirStatic)
app.Use(favicon.New(favicon.Config{
router.Use(favicon.New(favicon.Config{
URL: "/favicon.ico",
FileSystem: httpFS,
File: "static/favicon.ico",
}))
app.Use("/static", filesystem.New(filesystem.Config{
router.Use("/static", filesystem.New(filesystem.Config{
Root: httpFS,
PathPrefix: "static",
Browse: true,
@ -182,7 +181,7 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
// Define a custom 404 handler
// Note: keep this at the bottom!
app.Use(notFoundHandler)
router.Use(notFoundHandler)
return app, nil
return router, nil
}

View file

@ -5,24 +5,21 @@ import (
"context"
"embed"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
"strings"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/config"
. "github.com/mudler/LocalAI/core/http"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/startup"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/model"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"gopkg.in/yaml.v3"
@ -254,9 +251,6 @@ var _ = Describe("API test", func() {
var cancel context.CancelFunc
var tmpdir string
var modelDir string
var bcl *config.BackendConfigLoader
var ml *model.ModelLoader
var applicationConfig *config.ApplicationConfig
commonOpts := []config.AppOption{
config.WithDebug(true),
@ -302,7 +296,7 @@ var _ = Describe("API test", func() {
},
}
bcl, ml, applicationConfig, err = startup.Startup(
application, err := application.New(
append(commonOpts,
config.WithContext(c),
config.WithGalleries(galleries),
@ -312,7 +306,7 @@ var _ = Describe("API test", func() {
config.WithBackendAssetsOutput(backendAssetsDir))...)
Expect(err).ToNot(HaveOccurred())
app, err = App(bcl, ml, applicationConfig)
app, err = API(application)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
@ -541,7 +535,7 @@ var _ = Describe("API test", func() {
var res map[string]string
err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res)
Expect(err).ToNot(HaveOccurred())
Expect(res["location"]).To(Equal("San Francisco"), fmt.Sprint(res))
Expect(res["location"]).To(ContainSubstring("San Francisco"), fmt.Sprint(res))
Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res))
Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason))
@ -643,7 +637,7 @@ var _ = Describe("API test", func() {
},
}
bcl, ml, applicationConfig, err = startup.Startup(
application, err := application.New(
append(commonOpts,
config.WithContext(c),
config.WithAudioDir(tmpdir),
@ -654,7 +648,7 @@ var _ = Describe("API test", func() {
config.WithBackendAssetsOutput(tmpdir))...,
)
Expect(err).ToNot(HaveOccurred())
app, err = App(bcl, ml, applicationConfig)
app, err = API(application)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
@ -710,7 +704,7 @@ var _ = Describe("API test", func() {
Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp))
Expect(resp.StatusCode).To(Equal(200), fmt.Sprint(string(dat)))
Expect(resp.Header.Get("Content-Type")).To(Equal("audio/x-wav"))
Expect(resp.Header.Get("Content-Type")).To(Or(Equal("audio/x-wav"), Equal("audio/vnd.wave")))
})
It("installs and is capable to generate images", Label("stablediffusion"), func() {
if runtime.GOOS != "linux" {
@ -774,14 +768,14 @@ var _ = Describe("API test", func() {
var err error
bcl, ml, applicationConfig, err = startup.Startup(
application, err := application.New(
append(commonOpts,
config.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
config.WithContext(c),
config.WithModelPath(modelPath),
)...)
Expect(err).ToNot(HaveOccurred())
app, err = App(bcl, ml, applicationConfig)
app, err = API(application)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
@ -913,71 +907,6 @@ var _ = Describe("API test", func() {
})
})
Context("backends", func() {
It("runs rwkv completion", func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,"})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices) > 0).To(BeTrue())
Expect(resp.Choices[0].Text).To(ContainSubstring("five"))
stream, err := client.CreateCompletionStream(context.TODO(), openai.CompletionRequest{
Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,", Stream: true,
})
Expect(err).ToNot(HaveOccurred())
defer stream.Close()
tokens := 0
text := ""
for {
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
break
}
Expect(err).ToNot(HaveOccurred())
text += response.Choices[0].Text
tokens++
}
Expect(text).ToNot(BeEmpty())
Expect(text).To(ContainSubstring("five"))
Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
})
It("runs rwkv chat completion", func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
resp, err := client.CreateChatCompletion(context.TODO(),
openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices) > 0).To(BeTrue())
Expect(strings.ToLower(resp.Choices[0].Message.Content)).To(Or(ContainSubstring("sure"), ContainSubstring("five"), ContainSubstring("5")))
stream, err := client.CreateChatCompletionStream(context.TODO(), openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
Expect(err).ToNot(HaveOccurred())
defer stream.Close()
tokens := 0
text := ""
for {
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
break
}
Expect(err).ToNot(HaveOccurred())
text += response.Choices[0].Delta.Content
tokens++
}
Expect(text).ToNot(BeEmpty())
Expect(strings.ToLower(text)).To(Or(ContainSubstring("sure"), ContainSubstring("five")))
Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
})
})
// See tests/integration/stores_test
Context("Stores", Label("stores"), func() {
@ -1057,14 +986,14 @@ var _ = Describe("API test", func() {
c, cancel = context.WithCancel(context.Background())
var err error
bcl, ml, applicationConfig, err = startup.Startup(
application, err := application.New(
append(commonOpts,
config.WithContext(c),
config.WithModelPath(modelPath),
config.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
)
Expect(err).ToNot(HaveOccurred())
app, err = App(bcl, ml, applicationConfig)
app, err = API(application)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")

View file

@ -14,6 +14,8 @@ import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/templates"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
@ -24,7 +26,7 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/chat/completions [post]
func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
var id, textContentToReturn string
var created int
@ -294,148 +296,10 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
// If we are using the tokenizer template, we don't need to process the messages
// unless we are processing functions
if !config.TemplateConfig.UseTokenizerTemplate || shouldUseFn {
suppressConfigSystemPrompt := false
mess := []string{}
for messageIndex, i := range input.Messages {
var content string
role := i.Role
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
if (i.FunctionCall != nil || i.ToolCalls != nil) && i.Role == "assistant" {
roleFn := "assistant_function_call"
r := config.Roles[roleFn]
if r != "" {
role = roleFn
}
}
r := config.Roles[role]
contentExists := i.Content != nil && i.StringContent != ""
fcall := i.FunctionCall
if len(i.ToolCalls) > 0 {
fcall = i.ToolCalls
}
// First attempt to populate content via a chat message specific template
if config.TemplateConfig.ChatMessage != "" {
chatMessageData := model.ChatMessageTemplateData{
SystemPrompt: config.SystemPrompt,
Role: r,
RoleName: role,
Content: i.StringContent,
FunctionCall: fcall,
FunctionName: i.Name,
LastMessage: messageIndex == (len(input.Messages) - 1),
Function: config.Grammar != "" && (messageIndex == (len(input.Messages) - 1)),
MessageIndex: messageIndex,
}
templatedChatMessage, err := ml.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
if err != nil {
log.Error().Err(err).Interface("message", chatMessageData).Str("template", config.TemplateConfig.ChatMessage).Msg("error processing message with template, skipping")
} else {
if templatedChatMessage == "" {
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
}
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
content = templatedChatMessage
}
}
marshalAnyRole := func(f any) {
j, err := json.Marshal(f)
if err == nil {
if contentExists {
content += "\n" + fmt.Sprint(r, " ", string(j))
} else {
content = fmt.Sprint(r, " ", string(j))
}
}
}
marshalAny := func(f any) {
j, err := json.Marshal(f)
if err == nil {
if contentExists {
content += "\n" + string(j)
} else {
content = string(j)
}
}
}
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
if content == "" {
if r != "" {
if contentExists {
content = fmt.Sprint(r, i.StringContent)
}
if i.FunctionCall != nil {
marshalAnyRole(i.FunctionCall)
}
if i.ToolCalls != nil {
marshalAnyRole(i.ToolCalls)
}
} else {
if contentExists {
content = fmt.Sprint(i.StringContent)
}
if i.FunctionCall != nil {
marshalAny(i.FunctionCall)
}
if i.ToolCalls != nil {
marshalAny(i.ToolCalls)
}
}
// Special Handling: System. We care if it was printed at all, not the r branch, so check seperately
if contentExists && role == "system" {
suppressConfigSystemPrompt = true
}
}
mess = append(mess, content)
}
joinCharacter := "\n"
if config.TemplateConfig.JoinChatMessagesByCharacter != nil {
joinCharacter = *config.TemplateConfig.JoinChatMessagesByCharacter
}
predInput = strings.Join(mess, joinCharacter)
log.Debug().Msgf("Prompt (before templating): %s", predInput)
templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
}
if config.TemplateConfig.Chat != "" && !shouldUseFn {
templateFile = config.TemplateConfig.Chat
}
if config.TemplateConfig.Functions != "" && shouldUseFn {
templateFile = config.TemplateConfig.Functions
}
if templateFile != "" {
templatedInput, err := ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
SuppressSystemPrompt: suppressConfigSystemPrompt,
Input: predInput,
Functions: funcs,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
} else {
log.Debug().Msgf("Template failed loading: %s", err.Error())
}
}
predInput = evaluator.TemplateMessages(input.Messages, config, funcs, shouldUseFn)
log.Debug().Msgf("Prompt (after templating): %s", predInput)
if shouldUseFn && config.Grammar != "" {
if config.Grammar != "" {
log.Debug().Msgf("Grammar: %+v", config.Grammar)
}
}

View file

@ -16,6 +16,7 @@ import (
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/functions"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/templates"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
@ -25,7 +26,7 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/completions [post]
func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
id := uuid.New().String()
created := int(time.Now().Unix())
@ -94,17 +95,6 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
c.Set("Transfer-Encoding", "chunked")
}
templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
}
if config.TemplateConfig.Completion != "" {
templateFile = config.TemplateConfig.Completion
}
if input.Stream {
if len(config.PromptStrings) > 1 {
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
@ -112,15 +102,13 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
predInput := config.PromptStrings[0]
if templateFile != "" {
templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
Input: predInput,
SystemPrompt: config.SystemPrompt,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
}
templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{
Input: predInput,
SystemPrompt: config.SystemPrompt,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
}
responses := make(chan schema.OpenAIResponse)
@ -165,16 +153,13 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
totalTokenUsage := backend.TokenUsage{}
for k, i := range config.PromptStrings {
if templateFile != "" {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
Input: i,
})
if err == nil {
i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", i)
}
templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
Input: i,
})
if err == nil {
i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", i)
}
r, tokenUsage, err := ComputeChoices(

View file

@ -12,6 +12,7 @@ import (
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/schema"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/templates"
"github.com/rs/zerolog/log"
)
@ -21,7 +22,8 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/edits [post]
func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
if err != nil {
@ -35,31 +37,18 @@ func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConf
log.Debug().Msgf("Parameter Config: %+v", config)
templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
}
if config.TemplateConfig.Edit != "" {
templateFile = config.TemplateConfig.Edit
}
var result []schema.Choice
totalTokenUsage := backend.TokenUsage{}
for _, i := range config.InputStrings {
if templateFile != "" {
templatedInput, err := ml.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
Input: i,
Instruction: input.Instruction,
SystemPrompt: config.SystemPrompt,
})
if err == nil {
i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", i)
}
templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.EditPromptTemplate, *config, templates.PromptTemplateData{
Input: i,
Instruction: input.Instruction,
SystemPrompt: config.SystemPrompt,
})
if err == nil {
i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", i)
}
r, tokenUsage, err := ComputeChoices(input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) {

View file

@ -11,62 +11,62 @@ import (
"github.com/mudler/LocalAI/pkg/model"
)
func RegisterLocalAIRoutes(app *fiber.App,
func RegisterLocalAIRoutes(router *fiber.App,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
galleryService *services.GalleryService) {
app.Get("/swagger/*", swagger.HandlerDefault) // default
router.Get("/swagger/*", swagger.HandlerDefault) // default
// LocalAI API endpoints
if !appConfig.DisableGalleryEndpoint {
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService)
app.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
app.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())
router.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
router.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())
app.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint())
app.Get("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint())
app.Post("/models/galleries", modelGalleryEndpointService.AddModelGalleryEndpoint())
app.Delete("/models/galleries", modelGalleryEndpointService.RemoveModelGalleryEndpoint())
app.Get("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint())
app.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
router.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint())
router.Get("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint())
router.Post("/models/galleries", modelGalleryEndpointService.AddModelGalleryEndpoint())
router.Delete("/models/galleries", modelGalleryEndpointService.RemoveModelGalleryEndpoint())
router.Get("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint())
router.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
}
app.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig))
app.Post("/vad", localai.VADEndpoint(cl, ml, appConfig))
router.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig))
router.Post("/vad", localai.VADEndpoint(cl, ml, appConfig))
// Stores
sl := model.NewModelLoader("")
app.Post("/stores/set", localai.StoresSetEndpoint(sl, appConfig))
app.Post("/stores/delete", localai.StoresDeleteEndpoint(sl, appConfig))
app.Post("/stores/get", localai.StoresGetEndpoint(sl, appConfig))
app.Post("/stores/find", localai.StoresFindEndpoint(sl, appConfig))
router.Post("/stores/set", localai.StoresSetEndpoint(sl, appConfig))
router.Post("/stores/delete", localai.StoresDeleteEndpoint(sl, appConfig))
router.Post("/stores/get", localai.StoresGetEndpoint(sl, appConfig))
router.Post("/stores/find", localai.StoresFindEndpoint(sl, appConfig))
if !appConfig.DisableMetrics {
app.Get("/metrics", localai.LocalAIMetricsEndpoint())
router.Get("/metrics", localai.LocalAIMetricsEndpoint())
}
// Experimental Backend Statistics Module
backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now
app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
router.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
router.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
// p2p
if p2p.IsP2PEnabled() {
app.Get("/api/p2p", localai.ShowP2PNodes(appConfig))
app.Get("/api/p2p/token", localai.ShowP2PToken(appConfig))
router.Get("/api/p2p", localai.ShowP2PNodes(appConfig))
router.Get("/api/p2p/token", localai.ShowP2PToken(appConfig))
}
app.Get("/version", func(c *fiber.Ctx) error {
router.Get("/version", func(c *fiber.Ctx) error {
return c.JSON(struct {
Version string `json:"version"`
}{Version: internal.PrintableVersion()})
})
app.Get("/system", localai.SystemInformations(ml, appConfig))
router.Get("/system", localai.SystemInformations(ml, appConfig))
// misc
app.Post("/v1/tokenize", localai.TokenizeEndpoint(cl, ml, appConfig))
router.Post("/v1/tokenize", localai.TokenizeEndpoint(cl, ml, appConfig))
}

View file

@ -2,84 +2,134 @@ package routes
import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/http/endpoints/openai"
"github.com/mudler/LocalAI/pkg/model"
)
func RegisterOpenAIRoutes(app *fiber.App,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig) {
application *application.Application) {
// openAI compatible API endpoint
// chat
app.Post("/v1/chat/completions", openai.ChatEndpoint(cl, ml, appConfig))
app.Post("/chat/completions", openai.ChatEndpoint(cl, ml, appConfig))
app.Post("/v1/chat/completions",
openai.ChatEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
app.Post("/chat/completions",
openai.ChatEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
// edit
app.Post("/v1/edits", openai.EditEndpoint(cl, ml, appConfig))
app.Post("/edits", openai.EditEndpoint(cl, ml, appConfig))
app.Post("/v1/edits",
openai.EditEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
app.Post("/edits",
openai.EditEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
// assistant
app.Get("/v1/assistants", openai.ListAssistantsEndpoint(cl, ml, appConfig))
app.Get("/assistants", openai.ListAssistantsEndpoint(cl, ml, appConfig))
app.Post("/v1/assistants", openai.CreateAssistantEndpoint(cl, ml, appConfig))
app.Post("/assistants", openai.CreateAssistantEndpoint(cl, ml, appConfig))
app.Delete("/v1/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, ml, appConfig))
app.Delete("/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, ml, appConfig))
app.Get("/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, ml, appConfig))
app.Post("/v1/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, ml, appConfig))
app.Post("/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
app.Get("/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
app.Post("/v1/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
app.Post("/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
app.Delete("/v1/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
app.Delete("/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, ml, appConfig))
app.Get("/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants", openai.ListAssistantsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Get("/assistants", openai.ListAssistantsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Post("/v1/assistants", openai.CreateAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Post("/assistants", openai.CreateAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Delete("/v1/assistants/:assistant_id", openai.DeleteAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Delete("/assistants/:assistant_id", openai.DeleteAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Get("/v1/assistants/:assistant_id", openai.GetAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Get("/assistants/:assistant_id", openai.GetAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Post("/v1/assistants/:assistant_id", openai.ModifyAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Post("/assistants/:assistant_id", openai.ModifyAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Get("/v1/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Get("/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Post("/v1/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Post("/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Delete("/v1/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Delete("/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Get("/v1/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Get("/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
// files
app.Post("/v1/files", openai.UploadFilesEndpoint(cl, appConfig))
app.Post("/files", openai.UploadFilesEndpoint(cl, appConfig))
app.Get("/v1/files", openai.ListFilesEndpoint(cl, appConfig))
app.Get("/files", openai.ListFilesEndpoint(cl, appConfig))
app.Get("/v1/files/:file_id", openai.GetFilesEndpoint(cl, appConfig))
app.Get("/files/:file_id", openai.GetFilesEndpoint(cl, appConfig))
app.Delete("/v1/files/:file_id", openai.DeleteFilesEndpoint(cl, appConfig))
app.Delete("/files/:file_id", openai.DeleteFilesEndpoint(cl, appConfig))
app.Get("/v1/files/:file_id/content", openai.GetFilesContentsEndpoint(cl, appConfig))
app.Get("/files/:file_id/content", openai.GetFilesContentsEndpoint(cl, appConfig))
app.Post("/v1/files", openai.UploadFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
app.Post("/files", openai.UploadFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
app.Get("/v1/files", openai.ListFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
app.Get("/files", openai.ListFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
app.Get("/v1/files/:file_id", openai.GetFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
app.Get("/files/:file_id", openai.GetFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
app.Delete("/v1/files/:file_id", openai.DeleteFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
app.Delete("/files/:file_id", openai.DeleteFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
app.Get("/v1/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig()))
app.Get("/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig()))
// completion
app.Post("/v1/completions", openai.CompletionEndpoint(cl, ml, appConfig))
app.Post("/completions", openai.CompletionEndpoint(cl, ml, appConfig))
app.Post("/v1/engines/:model/completions", openai.CompletionEndpoint(cl, ml, appConfig))
app.Post("/v1/completions",
openai.CompletionEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
app.Post("/completions",
openai.CompletionEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
app.Post("/v1/engines/:model/completions",
openai.CompletionEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
// embeddings
app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig))
app.Post("/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig))
app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig))
app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Post("/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
// audio
app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(cl, ml, appConfig))
app.Post("/v1/audio/speech", localai.TTSEndpoint(cl, ml, appConfig))
app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Post("/v1/audio/speech", localai.TTSEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
// images
app.Post("/v1/images/generations", openai.ImageEndpoint(cl, ml, appConfig))
app.Post("/v1/images/generations", openai.ImageEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
if appConfig.ImageDir != "" {
app.Static("/generated-images", appConfig.ImageDir)
if application.ApplicationConfig().ImageDir != "" {
app.Static("/generated-images", application.ApplicationConfig().ImageDir)
}
if appConfig.AudioDir != "" {
app.Static("/generated-audio", appConfig.AudioDir)
if application.ApplicationConfig().AudioDir != "" {
app.Static("/generated-audio", application.ApplicationConfig().AudioDir)
}
// List models
app.Get("/v1/models", openai.ListModelsEndpoint(cl, ml))
app.Get("/models", openai.ListModelsEndpoint(cl, ml))
app.Get("/v1/models", openai.ListModelsEndpoint(application.BackendLoader(), application.ModelLoader()))
app.Get("/models", openai.ListModelsEndpoint(application.BackendLoader(), application.ModelLoader()))
}

View file

@ -194,8 +194,9 @@ diffusers:
pipeline_type: StableDiffusionPipeline
enable_parameters: "negative_prompt,num_inference_steps,clip_skip"
scheduler_type: "k_dpmpp_sde"
cfg_scale: 8
clip_skip: 11
cfg_scale: 8
```
#### Configuration parameters
@ -302,7 +303,8 @@ cuda: true
diffusers:
pipeline_type: StableDiffusionDepth2ImgPipeline
enable_parameters: "negative_prompt,num_inference_steps,image"
cfg_scale: 6
cfg_scale: 6
```
```bash

View file

@ -10,13 +10,13 @@ ico = "rocket_launch"
For installing LocalAI in Kubernetes, the deployment file from the `examples` can be used and customized as prefered:
```
kubectl apply -f https://raw.githubusercontent.com/mudler/LocalAI/master/examples/kubernetes/deployment.yaml
kubectl apply -f https://raw.githubusercontent.com/mudler/LocalAI-examples/refs/heads/main/kubernetes/deployment.yaml
```
For Nvidia GPUs:
```
kubectl apply -f https://raw.githubusercontent.com/mudler/LocalAI/master/examples/kubernetes/deployment-nvidia.yaml
kubectl apply -f https://raw.githubusercontent.com/mudler/LocalAI-examples/refs/heads/main/kubernetes/deployment-nvidia.yaml
```
Alternatively, the [helm chart](https://github.com/go-skynet/helm-charts) can be used as well:

View file

@ -6,7 +6,7 @@ weight = 24
url = "/model-compatibility/"
+++
Besides llama based models, LocalAI is compatible also with other architectures. The table below lists all the compatible models families and the associated binding repository.
Besides llama based models, LocalAI is compatible also with other architectures. The table below lists all the backends, compatible models families and the associated repository.
{{% alert note %}}
@ -16,19 +16,8 @@ LocalAI will attempt to automatically load models which are not explicitly confi
| Backend and Bindings | Compatible models | Completion/Chat endpoint | Capability | Embeddings support | Token stream support | Acceleration |
|----------------------------------------------------------------------------------|-----------------------|--------------------------|---------------------------|-----------------------------------|----------------------|--------------|
| [llama.cpp]({{%relref "docs/features/text-generation#llama.cpp" %}}) | Vicuna, Alpaca, LLaMa, Falcon, Starcoder, GPT-2, [and many others](https://github.com/ggerganov/llama.cpp?tab=readme-ov-file#description) | yes | GPT and Functions | yes** | yes | CUDA, openCL, cuBLAS, Metal |
| [gpt4all-llama](https://github.com/nomic-ai/gpt4all) | Vicuna, Alpaca, LLaMa | yes | GPT | no | yes | N/A |
| [gpt4all-mpt](https://github.com/nomic-ai/gpt4all) | MPT | yes | GPT | no | yes | N/A |
| [gpt4all-j](https://github.com/nomic-ai/gpt4all) | GPT4ALL-J | yes | GPT | no | yes | N/A |
| [falcon-ggml](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-ggml-transformers.cpp)) | Falcon (*) | yes | GPT | no | no | N/A |
| [dolly](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-ggml-transformers.cpp)) | Dolly | yes | GPT | no | no | N/A |
| [gptj](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-ggml-transformers.cpp)) | GPTJ | yes | GPT | no | no | N/A |
| [mpt](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-ggml-transformers.cpp)) | MPT | yes | GPT | no | no | N/A |
| [replit](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-ggml-transformers.cpp)) | Replit | yes | GPT | no | no | N/A |
| [gptneox](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-ggml-transformers.cpp)) | GPT NeoX, RedPajama, StableLM | yes | GPT | no | no | N/A |
| [bloomz](https://github.com/NouamaneTazi/bloomz.cpp) ([binding](https://github.com/go-skynet/bloomz.cpp)) | Bloom | yes | GPT | no | no | N/A |
| [rwkv](https://github.com/saharNooby/rwkv.cpp) ([binding](https://github.com/donomii/go-rwkv.cpp)) | rwkv | yes | GPT | no | yes | N/A |
| [bert](https://github.com/skeskinen/bert.cpp) ([binding](https://github.com/go-skynet/go-bert.cpp)) | bert | no | Embeddings only | yes | no | N/A |
| [llama.cpp]({{%relref "docs/features/text-generation#llama.cpp" %}}) | LLama, Mamba, RWKV, Falcon, Starcoder, GPT-2, [and many others](https://github.com/ggerganov/llama.cpp?tab=readme-ov-file#description) | yes | GPT and Functions | yes** | yes | CUDA, openCL, cuBLAS, Metal |
| [llama.cpp's ggml model (backward compatibility with old format, before GGUF)](https://github.com/ggerganov/llama.cpp) ([binding](https://github.com/go-skynet/go-llama.cpp)) | LLama, GPT-2, [and many others](https://github.com/ggerganov/llama.cpp?tab=readme-ov-file#description) | yes | GPT and Functions | yes** | yes | CUDA, openCL, cuBLAS, Metal |
| [whisper](https://github.com/ggerganov/whisper.cpp) | whisper | no | Audio | no | no | N/A |
| [stablediffusion](https://github.com/EdVince/Stable-Diffusion-NCNN) ([binding](https://github.com/mudler/go-stable-diffusion)) | stablediffusion | no | Image | no | no | N/A |
| [langchain-huggingface](https://github.com/tmc/langchaingo) | Any text generators available on HuggingFace through API | yes | GPT | no | no | N/A |
@ -40,11 +29,18 @@ LocalAI will attempt to automatically load models which are not explicitly confi
| `diffusers` | SD,... | no | Image generation | no | no | N/A |
| `vall-e-x` | Vall-E | no | Audio generation and Voice cloning | no | no | CPU/CUDA |
| `vllm` | Various GPTs and quantization formats | yes | GPT | no | no | CPU/CUDA |
| `mamba` | Mamba models architecture | yes | GPT | no | no | CPU/CUDA |
| `exllama2` | GPTQ | yes | GPT only | no | no | N/A |
| `transformers-musicgen` | | no | Audio generation | no | no | N/A |
| [tinydream](https://github.com/symisc/tiny-dream#tiny-dreaman-embedded-header-only-stable-diffusion-inference-c-librarypixlabiotiny-dream) | stablediffusion | no | Image | no | no | N/A |
| `coqui` | Coqui | no | Audio generation and Voice cloning | no | no | CPU/CUDA |
| `openvoice` | Open voice | no | Audio generation and Voice cloning | no | no | CPU/CUDA |
| `parler-tts` | Open voice | no | Audio generation and Voice cloning | no | no | CPU/CUDA |
| [rerankers](https://github.com/AnswerDotAI/rerankers) | Reranking API | no | Reranking | no | no | CPU/CUDA |
| `transformers` | Various GPTs and quantization formats | yes | GPT, embeddings | yes | yes**** | CPU/CUDA/XPU |
| [bark-cpp](https://github.com/PABannier/bark.cpp) | bark | no | Audio-Only | no | no | yes |
| [stablediffusion-cpp](https://github.com/leejet/stable-diffusion.cpp) | stablediffusion-1, stablediffusion-2, stablediffusion-3, flux, PhotoMaker | no | Image | no | no | N/A |
| [silero-vad](https://github.com/snakers4/silero-vad) with [Golang bindings](https://github.com/streamer45/silero-vad-go) | Silero VAD | no | Voice Activity Detection | no | no | CPU |
Note: any backend name listed above can be used in the `backend` field of the model configuration file (See [the advanced section]({{%relref "docs/advanced" %}})).

View file

@ -1,3 +1,3 @@
{
"version": "v2.23.0"
"version": "v2.24.2"
}

@ -1 +1 @@
Subproject commit 28fce6b04c414523280c53ee02f9f3a94d9d23da
Subproject commit bd1f3d3432632c61bb12e7ec0f7673fed0289f19

12
gallery/flux-ggml.yaml Normal file
View file

@ -0,0 +1,12 @@
---
name: "flux-ggml"
config_file: |
backend: stablediffusion-ggml
step: 25
options:
- "diffusion_model"
- "clip_l_path:clip_l.safetensors"
- "t5xxl_path:t5xxl_fp16.safetensors"
- "vae_path:ae.safetensors"
- "sampler:euler"

View file

@ -11,4 +11,5 @@ config_file: |
cuda: true
enable_parameters: num_inference_steps
pipeline_type: FluxPipeline
cfg_scale: 0
cfg_scale: 0

File diff suppressed because it is too large Load diff

View file

@ -16,6 +16,7 @@ config_file: |
stopwords:
- 'Assistant:'
- '<s>'
template:
chat: "{{.Input}}\nAssistant: "

5
go.mod
View file

@ -76,6 +76,7 @@ require (
cloud.google.com/go/auth v0.4.1 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
cloud.google.com/go/compute/metadata v0.3.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/envoyproxy/protoc-gen-validate v1.0.4 // indirect
github.com/fasthttp/websocket v1.5.3 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
@ -84,8 +85,12 @@ require (
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.4 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/nikolalohinski/gonja/v2 v2.3.2 // indirect
github.com/pion/datachannel v1.5.8 // indirect
github.com/pion/dtls/v2 v2.2.12 // indirect
github.com/pion/ice/v2 v2.3.34 // indirect

12
go.sum
View file

@ -140,6 +140,8 @@ github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 h1:iFaUwBSo5Svw6L
github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5/go.mod h1:qssHWj60/X5sZFNxpG4HBPDHVqxNm4DfnCKgrbZOT+s=
github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY=
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/elastic/gosigar v0.12.0/go.mod h1:iXRIGg2tLnu7LBdpqzyQfGDEidKCfWcCMS0WKyPWoMs=
github.com/elastic/gosigar v0.14.3 h1:xwkKwPia+hSfg9GqrCUKYdId102m9qTJIIr7egmK/uo=
github.com/elastic/gosigar v0.14.3/go.mod h1:iXRIGg2tLnu7LBdpqzyQfGDEidKCfWcCMS0WKyPWoMs=
@ -268,6 +270,7 @@ github.com/google/go-containerregistry v0.19.2 h1:TannFKE1QSajsP6hPWb5oJNgKe1IKj
github.com/google/go-containerregistry v0.19.2/go.mod h1:YCMFNQeeXeLF+dnhhWkqDItx/JSkH01j1Kis4PsjzFI=
github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ=
github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
@ -353,6 +356,8 @@ github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwA
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
@ -474,8 +479,12 @@ github.com/moby/sys/sequential v0.5.0 h1:OPvI35Lzn9K04PBbCLW0g4LcFAJgHsvXsRyewg5
github.com/moby/sys/sequential v0.5.0/go.mod h1:tH2cOOs5V9MlPiXcQzRC+eEyab644PWKGRYaaV5ZZlo=
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/mr-tron/base58 v1.1.2/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc=
github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o=
github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc=
@ -519,6 +528,9 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo=
github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
github.com/nikolalohinski/gonja v1.5.3 h1:GsA+EEaZDZPGJ8JtpeGN78jidhOlxeJROpqMT9fTj9c=
github.com/nikolalohinski/gonja/v2 v2.3.2 h1:UgLFfqi7L9XfX0PEcE4eUpvGojVQL5KhBfJJaBp7ZxY=
github.com/nikolalohinski/gonja/v2 v2.3.2/go.mod h1:1Wcc/5huTu6y36e0sOFR1XQoFlylw3c3H3L5WOz0RDg=
github.com/nwaples/rardecode v1.1.0 h1:vSxaY8vQhOcVr4mm5e8XllHWTiM4JF507A0Katqw7MQ=
github.com/nwaples/rardecode v1.1.0/go.mod h1:5DzqNKiOdpKKBH87u8VlvAnPZMXcGRhxWkRpHbbfGS0=
github.com/nxadm/tail v1.4.11 h1:8feyoE3OzPrcshW5/MJ4sGESc5cqmGkGCWlco4l0bqY=

View file

@ -9,8 +9,6 @@ import (
"sync"
"time"
"github.com/mudler/LocalAI/pkg/templates"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
@ -23,7 +21,6 @@ type ModelLoader struct {
ModelPath string
mu sync.Mutex
models map[string]*Model
templates *templates.TemplateCache
wd *WatchDog
}
@ -31,7 +28,6 @@ func NewModelLoader(modelPath string) *ModelLoader {
nml := &ModelLoader{
ModelPath: modelPath,
models: make(map[string]*Model),
templates: templates.NewTemplateCache(modelPath),
}
return nml

View file

@ -1,52 +0,0 @@
package model
import (
"fmt"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/templates"
)
// Rather than pass an interface{} to the prompt template:
// These are the definitions of all possible variables LocalAI will currently populate for use in a prompt template file
// Please note: Not all of these are populated on every endpoint - your template should either be tested for each endpoint you map it to, or tolerant of zero values.
type PromptTemplateData struct {
SystemPrompt string
SuppressSystemPrompt bool // used by chat specifically to indicate that SystemPrompt above should be _ignored_
Input string
Instruction string
Functions []functions.Function
MessageIndex int
}
type ChatMessageTemplateData struct {
SystemPrompt string
Role string
RoleName string
FunctionName string
Content string
MessageIndex int
Function bool
FunctionCall interface{}
LastMessage bool
}
const (
ChatPromptTemplate templates.TemplateType = iota
ChatMessageTemplate
CompletionPromptTemplate
EditPromptTemplate
FunctionsPromptTemplate
)
func (ml *ModelLoader) EvaluateTemplateForPrompt(templateType templates.TemplateType, templateName string, in PromptTemplateData) (string, error) {
// TODO: should this check be improved?
if templateType == ChatMessageTemplate {
return "", fmt.Errorf("invalid templateType: ChatMessage")
}
return ml.templates.EvaluateTemplate(templateType, templateName, in)
}
func (ml *ModelLoader) EvaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) {
return ml.templates.EvaluateTemplate(ChatMessageTemplate, templateName, messageData)
}

View file

@ -1,197 +0,0 @@
package model_test
import (
. "github.com/mudler/LocalAI/pkg/model"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
const chatML = `<|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}
{{- if .FunctionCall }}
<tool_call>
{{- else if eq .RoleName "tool" }}
<tool_response>
{{- end }}
{{- if .Content}}
{{.Content }}
{{- end }}
{{- if .FunctionCall}}
{{toJson .FunctionCall}}
{{- end }}
{{- if .FunctionCall }}
</tool_call>
{{- else if eq .RoleName "tool" }}
</tool_response>
{{- end }}<|im_end|>`
const llama3 = `<|start_header_id|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}<|end_header_id|>
{{ if .FunctionCall -}}
Function call:
{{ else if eq .RoleName "tool" -}}
Function response:
{{ end -}}
{{ if .Content -}}
{{.Content -}}
{{ else if .FunctionCall -}}
{{ toJson .FunctionCall -}}
{{ end -}}
<|eot_id|>`
var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
"user": {
"template": llama3,
"expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
"data": ChatMessageTemplateData{
SystemPrompt: "",
Role: "user",
RoleName: "user",
Content: "A long time ago in a galaxy far, far away...",
FunctionCall: nil,
FunctionName: "",
LastMessage: false,
Function: false,
MessageIndex: 0,
},
},
"assistant": {
"template": llama3,
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
"data": ChatMessageTemplateData{
SystemPrompt: "",
Role: "assistant",
RoleName: "assistant",
Content: "A long time ago in a galaxy far, far away...",
FunctionCall: nil,
FunctionName: "",
LastMessage: false,
Function: false,
MessageIndex: 0,
},
},
"function_call": {
"template": llama3,
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>",
"data": ChatMessageTemplateData{
SystemPrompt: "",
Role: "assistant",
RoleName: "assistant",
Content: "",
FunctionCall: map[string]string{"function": "test"},
FunctionName: "",
LastMessage: false,
Function: false,
MessageIndex: 0,
},
},
"function_response": {
"template": llama3,
"expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>",
"data": ChatMessageTemplateData{
SystemPrompt: "",
Role: "tool",
RoleName: "tool",
Content: "Response from tool",
FunctionCall: nil,
FunctionName: "",
LastMessage: false,
Function: false,
MessageIndex: 0,
},
},
}
var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
"user": {
"template": chatML,
"expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...<|im_end|>",
"data": ChatMessageTemplateData{
SystemPrompt: "",
Role: "user",
RoleName: "user",
Content: "A long time ago in a galaxy far, far away...",
FunctionCall: nil,
FunctionName: "",
LastMessage: false,
Function: false,
MessageIndex: 0,
},
},
"assistant": {
"template": chatML,
"expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...<|im_end|>",
"data": ChatMessageTemplateData{
SystemPrompt: "",
Role: "assistant",
RoleName: "assistant",
Content: "A long time ago in a galaxy far, far away...",
FunctionCall: nil,
FunctionName: "",
LastMessage: false,
Function: false,
MessageIndex: 0,
},
},
"function_call": {
"template": chatML,
"expected": "<|im_start|>assistant\n<tool_call>\n{\"function\":\"test\"}\n</tool_call><|im_end|>",
"data": ChatMessageTemplateData{
SystemPrompt: "",
Role: "assistant",
RoleName: "assistant",
Content: "",
FunctionCall: map[string]string{"function": "test"},
FunctionName: "",
LastMessage: false,
Function: false,
MessageIndex: 0,
},
},
"function_response": {
"template": chatML,
"expected": "<|im_start|>tool\n<tool_response>\nResponse from tool\n</tool_response><|im_end|>",
"data": ChatMessageTemplateData{
SystemPrompt: "",
Role: "tool",
RoleName: "tool",
Content: "Response from tool",
FunctionCall: nil,
FunctionName: "",
LastMessage: false,
Function: false,
MessageIndex: 0,
},
},
}
var _ = Describe("Templates", func() {
Context("chat message ChatML", func() {
var modelLoader *ModelLoader
BeforeEach(func() {
modelLoader = NewModelLoader("")
})
for key := range chatMLTestMatch {
foo := chatMLTestMatch[key]
It("renders correctly `"+key+"`", func() {
templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData))
Expect(err).ToNot(HaveOccurred())
Expect(templated).To(Equal(foo["expected"]), templated)
})
}
})
Context("chat message llama3", func() {
var modelLoader *ModelLoader
BeforeEach(func() {
modelLoader = NewModelLoader("")
})
for key := range llama3TestMatch {
foo := llama3TestMatch[key]
It("renders correctly `"+key+"`", func() {
templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData))
Expect(err).ToNot(HaveOccurred())
Expect(templated).To(Equal(foo["expected"]), templated)
})
}
})
})

View file

@ -11,59 +11,41 @@ import (
"github.com/mudler/LocalAI/pkg/utils"
"github.com/Masterminds/sprig/v3"
"github.com/nikolalohinski/gonja/v2"
"github.com/nikolalohinski/gonja/v2/exec"
)
// Keep this in sync with config.TemplateConfig. Is there a more idiomatic way to accomplish this in go?
// Technically, order doesn't _really_ matter, but the count must stay in sync, see tests/integration/reflect_test.go
type TemplateType int
type TemplateCache struct {
mu sync.Mutex
templatesPath string
templates map[TemplateType]map[string]*template.Template
type templateCache struct {
mu sync.Mutex
templatesPath string
templates map[TemplateType]map[string]*template.Template
jinjaTemplates map[TemplateType]map[string]*exec.Template
}
func NewTemplateCache(templatesPath string) *TemplateCache {
tc := &TemplateCache{
templatesPath: templatesPath,
templates: make(map[TemplateType]map[string]*template.Template),
func newTemplateCache(templatesPath string) *templateCache {
tc := &templateCache{
templatesPath: templatesPath,
templates: make(map[TemplateType]map[string]*template.Template),
jinjaTemplates: make(map[TemplateType]map[string]*exec.Template),
}
return tc
}
func (tc *TemplateCache) initializeTemplateMapKey(tt TemplateType) {
func (tc *templateCache) initializeTemplateMapKey(tt TemplateType) {
if _, ok := tc.templates[tt]; !ok {
tc.templates[tt] = make(map[string]*template.Template)
}
}
func (tc *TemplateCache) EvaluateTemplate(templateType TemplateType, templateName string, in interface{}) (string, error) {
tc.mu.Lock()
defer tc.mu.Unlock()
tc.initializeTemplateMapKey(templateType)
m, ok := tc.templates[templateType][templateName]
if !ok {
// return "", fmt.Errorf("template not loaded: %s", templateName)
loadErr := tc.loadTemplateIfExists(templateType, templateName)
if loadErr != nil {
return "", loadErr
}
m = tc.templates[templateType][templateName] // ok is not important since we check m on the next line, and wealready checked
}
if m == nil {
return "", fmt.Errorf("failed loading a template for %s", templateName)
}
var buf bytes.Buffer
if err := m.Execute(&buf, in); err != nil {
return "", err
}
return buf.String(), nil
func (tc *templateCache) existsInModelPath(s string) bool {
return utils.ExistsInPath(tc.templatesPath, s)
}
func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error {
func (tc *templateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error {
// Check if the template was already loaded
if _, ok := tc.templates[templateType][templateName]; ok {
@ -82,6 +64,51 @@ func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templat
return fmt.Errorf("template file outside path: %s", file)
}
// can either be a file in the system or a string with the template
if tc.existsInModelPath(modelTemplateFile) {
d, err := os.ReadFile(file)
if err != nil {
return err
}
dat = string(d)
} else {
dat = templateName
}
// Parse the template
tmpl, err := template.New("prompt").Funcs(sprig.FuncMap()).Parse(dat)
if err != nil {
return err
}
tc.templates[templateType][templateName] = tmpl
return nil
}
func (tc *templateCache) initializeJinjaTemplateMapKey(tt TemplateType) {
if _, ok := tc.jinjaTemplates[tt]; !ok {
tc.jinjaTemplates[tt] = make(map[string]*exec.Template)
}
}
func (tc *templateCache) loadJinjaTemplateIfExists(templateType TemplateType, templateName string) error {
// Check if the template was already loaded
if _, ok := tc.jinjaTemplates[templateType][templateName]; ok {
return nil
}
// Check if the model path exists
// skip any error here - we run anyway if a template does not exist
modelTemplateFile := fmt.Sprintf("%s.tmpl", templateName)
dat := ""
file := filepath.Join(tc.templatesPath, modelTemplateFile)
// Security check
if err := utils.VerifyPath(modelTemplateFile, tc.templatesPath); err != nil {
return fmt.Errorf("template file outside path: %s", file)
}
// can either be a file in the system or a string with the template
if utils.ExistsInPath(tc.templatesPath, modelTemplateFile) {
d, err := os.ReadFile(file)
@ -93,12 +120,65 @@ func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templat
dat = templateName
}
// Parse the template
tmpl, err := template.New("prompt").Funcs(sprig.FuncMap()).Parse(dat)
tmpl, err := gonja.FromString(dat)
if err != nil {
return err
}
tc.templates[templateType][templateName] = tmpl
tc.jinjaTemplates[templateType][templateName] = tmpl
return nil
}
func (tc *templateCache) evaluateJinjaTemplate(templateType TemplateType, templateNameOrContent string, in map[string]interface{}) (string, error) {
tc.mu.Lock()
defer tc.mu.Unlock()
tc.initializeJinjaTemplateMapKey(templateType)
m, ok := tc.jinjaTemplates[templateType][templateNameOrContent]
if !ok {
// return "", fmt.Errorf("template not loaded: %s", templateName)
loadErr := tc.loadJinjaTemplateIfExists(templateType, templateNameOrContent)
if loadErr != nil {
return "", loadErr
}
m = tc.jinjaTemplates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and wealready checked
}
if m == nil {
return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
}
var buf bytes.Buffer
data := exec.NewContext(in)
if err := m.Execute(&buf, data); err != nil {
return "", err
}
return buf.String(), nil
}
func (tc *templateCache) evaluateTemplate(templateType TemplateType, templateNameOrContent string, in interface{}) (string, error) {
tc.mu.Lock()
defer tc.mu.Unlock()
tc.initializeTemplateMapKey(templateType)
m, ok := tc.templates[templateType][templateNameOrContent]
if !ok {
// return "", fmt.Errorf("template not loaded: %s", templateName)
loadErr := tc.loadTemplateIfExists(templateType, templateNameOrContent)
if loadErr != nil {
return "", loadErr
}
m = tc.templates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and wealready checked
}
if m == nil {
return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
}
var buf bytes.Buffer
if err := m.Execute(&buf, in); err != nil {
return "", err
}
return buf.String(), nil
}

View file

@ -1,73 +0,0 @@
package templates_test
import (
"os"
"path/filepath"
"github.com/mudler/LocalAI/pkg/templates" // Update with your module path
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("TemplateCache", func() {
var (
templateCache *templates.TemplateCache
tempDir string
)
BeforeEach(func() {
var err error
tempDir, err = os.MkdirTemp("", "templates")
Expect(err).NotTo(HaveOccurred())
// Writing example template files
err = os.WriteFile(filepath.Join(tempDir, "example.tmpl"), []byte("Hello, {{.Name}}!"), 0600)
Expect(err).NotTo(HaveOccurred())
err = os.WriteFile(filepath.Join(tempDir, "empty.tmpl"), []byte(""), 0600)
Expect(err).NotTo(HaveOccurred())
templateCache = templates.NewTemplateCache(tempDir)
})
AfterEach(func() {
os.RemoveAll(tempDir) // Clean up
})
Describe("EvaluateTemplate", func() {
Context("when template is loaded successfully", func() {
It("should evaluate the template correctly", func() {
result, err := templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"})
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal("Hello, Gopher!"))
})
})
Context("when template isn't a file", func() {
It("should parse from string", func() {
result, err := templateCache.EvaluateTemplate(1, "{{.Name}}", map[string]string{"Name": "Gopher"})
Expect(err).ToNot(HaveOccurred())
Expect(result).To(Equal("Gopher"))
})
})
Context("when template is empty", func() {
It("should return an empty string", func() {
result, err := templateCache.EvaluateTemplate(1, "empty", nil)
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal(""))
})
})
})
Describe("concurrency", func() {
It("should handle multiple concurrent accesses", func(done Done) {
go func() {
_, _ = templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"})
}()
go func() {
_, _ = templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"})
}()
close(done)
}, 0.1) // timeout in seconds
})
})

295
pkg/templates/evaluator.go Normal file
View file

@ -0,0 +1,295 @@
package templates
import (
"encoding/json"
"fmt"
"strings"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/rs/zerolog/log"
)
// Rather than pass an interface{} to the prompt template:
// These are the definitions of all possible variables LocalAI will currently populate for use in a prompt template file
// Please note: Not all of these are populated on every endpoint - your template should either be tested for each endpoint you map it to, or tolerant of zero values.
type PromptTemplateData struct {
SystemPrompt string
SuppressSystemPrompt bool // used by chat specifically to indicate that SystemPrompt above should be _ignored_
Input string
Instruction string
Functions []functions.Function
MessageIndex int
}
type ChatMessageTemplateData struct {
SystemPrompt string
Role string
RoleName string
FunctionName string
Content string
MessageIndex int
Function bool
FunctionCall interface{}
LastMessage bool
}
const (
ChatPromptTemplate TemplateType = iota
ChatMessageTemplate
CompletionPromptTemplate
EditPromptTemplate
FunctionsPromptTemplate
)
type Evaluator struct {
cache *templateCache
}
func NewEvaluator(modelPath string) *Evaluator {
return &Evaluator{
cache: newTemplateCache(modelPath),
}
}
func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config config.BackendConfig, in PromptTemplateData) (string, error) {
template := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if e.cache.existsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
template = config.Model
}
switch templateType {
case CompletionPromptTemplate:
if config.TemplateConfig.Completion != "" {
template = config.TemplateConfig.Completion
}
case EditPromptTemplate:
if config.TemplateConfig.Edit != "" {
template = config.TemplateConfig.Edit
}
case ChatPromptTemplate:
if config.TemplateConfig.Chat != "" {
template = config.TemplateConfig.Chat
}
case FunctionsPromptTemplate:
if config.TemplateConfig.Functions != "" {
template = config.TemplateConfig.Functions
}
}
if template == "" {
return in.Input, nil
}
if config.TemplateConfig.JinjaTemplate {
return e.evaluateJinjaTemplateForPrompt(templateType, template, in)
}
return e.cache.evaluateTemplate(templateType, template, in)
}
func (e *Evaluator) evaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) {
return e.cache.evaluateTemplate(ChatMessageTemplate, templateName, messageData)
}
func (e *Evaluator) templateJinjaChat(templateName string, messageData []ChatMessageTemplateData, funcs []functions.Function) (string, error) {
conversation := make(map[string]interface{})
messages := make([]map[string]interface{}, len(messageData))
// convert from ChatMessageTemplateData to what the jinja template expects
for _, message := range messageData {
// TODO: this seems to cover minimum text templates. Can be expanded to cover more complex interactions
var data []byte
data, _ = json.Marshal(message.FunctionCall)
messages = append(messages, map[string]interface{}{
"role": message.RoleName,
"content": message.Content,
"tool_call": string(data),
})
}
conversation["messages"] = messages
// if tools are detected, add these
if len(funcs) > 0 {
conversation["tools"] = funcs
}
return e.cache.evaluateJinjaTemplate(ChatMessageTemplate, templateName, conversation)
}
func (e *Evaluator) evaluateJinjaTemplateForPrompt(templateType TemplateType, templateName string, in PromptTemplateData) (string, error) {
conversation := make(map[string]interface{})
conversation["system_prompt"] = in.SystemPrompt
conversation["content"] = in.Input
return e.cache.evaluateJinjaTemplate(templateType, templateName, conversation)
}
func (e *Evaluator) TemplateMessages(messages []schema.Message, config *config.BackendConfig, funcs []functions.Function, shouldUseFn bool) string {
if config.TemplateConfig.JinjaTemplate {
var messageData []ChatMessageTemplateData
for messageIndex, i := range messages {
fcall := i.FunctionCall
if len(i.ToolCalls) > 0 {
fcall = i.ToolCalls
}
messageData = append(messageData, ChatMessageTemplateData{
SystemPrompt: config.SystemPrompt,
Role: config.Roles[i.Role],
RoleName: i.Role,
Content: i.StringContent,
FunctionCall: fcall,
FunctionName: i.Name,
LastMessage: messageIndex == (len(messages) - 1),
Function: config.Grammar != "" && (messageIndex == (len(messages) - 1)),
MessageIndex: messageIndex,
})
}
templatedInput, err := e.templateJinjaChat(config.TemplateConfig.ChatMessage, messageData, funcs)
if err == nil {
return templatedInput
}
}
var predInput string
suppressConfigSystemPrompt := false
mess := []string{}
for messageIndex, i := range messages {
var content string
role := i.Role
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
if (i.FunctionCall != nil || i.ToolCalls != nil) && i.Role == "assistant" {
roleFn := "assistant_function_call"
r := config.Roles[roleFn]
if r != "" {
role = roleFn
}
}
r := config.Roles[role]
contentExists := i.Content != nil && i.StringContent != ""
fcall := i.FunctionCall
if len(i.ToolCalls) > 0 {
fcall = i.ToolCalls
}
// First attempt to populate content via a chat message specific template
if config.TemplateConfig.ChatMessage != "" {
chatMessageData := ChatMessageTemplateData{
SystemPrompt: config.SystemPrompt,
Role: r,
RoleName: role,
Content: i.StringContent,
FunctionCall: fcall,
FunctionName: i.Name,
LastMessage: messageIndex == (len(messages) - 1),
Function: config.Grammar != "" && (messageIndex == (len(messages) - 1)),
MessageIndex: messageIndex,
}
templatedChatMessage, err := e.evaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
if err != nil {
log.Error().Err(err).Interface("message", chatMessageData).Str("template", config.TemplateConfig.ChatMessage).Msg("error processing message with template, skipping")
} else {
if templatedChatMessage == "" {
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
}
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
content = templatedChatMessage
}
}
marshalAnyRole := func(f any) {
j, err := json.Marshal(f)
if err == nil {
if contentExists {
content += "\n" + fmt.Sprint(r, " ", string(j))
} else {
content = fmt.Sprint(r, " ", string(j))
}
}
}
marshalAny := func(f any) {
j, err := json.Marshal(f)
if err == nil {
if contentExists {
content += "\n" + string(j)
} else {
content = string(j)
}
}
}
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
if content == "" {
if r != "" {
if contentExists {
content = fmt.Sprint(r, i.StringContent)
}
if i.FunctionCall != nil {
marshalAnyRole(i.FunctionCall)
}
if i.ToolCalls != nil {
marshalAnyRole(i.ToolCalls)
}
} else {
if contentExists {
content = fmt.Sprint(i.StringContent)
}
if i.FunctionCall != nil {
marshalAny(i.FunctionCall)
}
if i.ToolCalls != nil {
marshalAny(i.ToolCalls)
}
}
// Special Handling: System. We care if it was printed at all, not the r branch, so check seperately
if contentExists && role == "system" {
suppressConfigSystemPrompt = true
}
}
mess = append(mess, content)
}
joinCharacter := "\n"
if config.TemplateConfig.JoinChatMessagesByCharacter != nil {
joinCharacter = *config.TemplateConfig.JoinChatMessagesByCharacter
}
predInput = strings.Join(mess, joinCharacter)
log.Debug().Msgf("Prompt (before templating): %s", predInput)
promptTemplate := ChatPromptTemplate
if config.TemplateConfig.Functions != "" && shouldUseFn {
promptTemplate = FunctionsPromptTemplate
}
templatedInput, err := e.EvaluateTemplateForPrompt(promptTemplate, *config, PromptTemplateData{
SystemPrompt: config.SystemPrompt,
SuppressSystemPrompt: suppressConfigSystemPrompt,
Input: predInput,
Functions: funcs,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
} else {
log.Debug().Msgf("Template failed loading: %s", err.Error())
}
return predInput
}

View file

@ -0,0 +1,253 @@
package templates_test
import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/functions"
. "github.com/mudler/LocalAI/pkg/templates"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
const toolCallJinja = `{{ '<|begin_of_text|>' }}{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ '<|start_header_id|>system<|end_header_id|>
' + system_message + '<|eot_id|>' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|start_header_id|>user<|end_header_id|>
' + content + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>
' }}{% elif message['role'] == 'assistant' %}{{ content + '<|eot_id|>' }}{% endif %}{% endfor %}`
const chatML = `<|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}
{{- if .FunctionCall }}
<tool_call>
{{- else if eq .RoleName "tool" }}
<tool_response>
{{- end }}
{{- if .Content}}
{{.Content }}
{{- end }}
{{- if .FunctionCall}}
{{toJson .FunctionCall}}
{{- end }}
{{- if .FunctionCall }}
</tool_call>
{{- else if eq .RoleName "tool" }}
</tool_response>
{{- end }}<|im_end|>`
const llama3 = `<|start_header_id|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}<|end_header_id|>
{{ if .FunctionCall -}}
Function call:
{{ else if eq .RoleName "tool" -}}
Function response:
{{ end -}}
{{ if .Content -}}
{{.Content -}}
{{ else if .FunctionCall -}}
{{ toJson .FunctionCall -}}
{{ end -}}
<|eot_id|>`
var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
"user": {
"expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
"config": &config.BackendConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: llama3,
},
},
"functions": []functions.Function{},
"shouldUseFn": false,
"messages": []schema.Message{
{
Role: "user",
StringContent: "A long time ago in a galaxy far, far away...",
},
},
},
"assistant": {
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
"config": &config.BackendConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: llama3,
},
},
"functions": []functions.Function{},
"messages": []schema.Message{
{
Role: "assistant",
StringContent: "A long time ago in a galaxy far, far away...",
},
},
"shouldUseFn": false,
},
"function_call": {
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>",
"config": &config.BackendConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: llama3,
},
},
"functions": []functions.Function{},
"messages": []schema.Message{
{
Role: "assistant",
FunctionCall: map[string]string{"function": "test"},
},
},
"shouldUseFn": false,
},
"function_response": {
"expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>",
"config": &config.BackendConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: llama3,
},
},
"functions": []functions.Function{},
"messages": []schema.Message{
{
Role: "tool",
StringContent: "Response from tool",
},
},
"shouldUseFn": false,
},
}
var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
"user": {
"expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...<|im_end|>",
"config": &config.BackendConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: chatML,
},
},
"functions": []functions.Function{},
"shouldUseFn": false,
"messages": []schema.Message{
{
Role: "user",
StringContent: "A long time ago in a galaxy far, far away...",
},
},
},
"assistant": {
"expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...<|im_end|>",
"config": &config.BackendConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: chatML,
},
},
"functions": []functions.Function{},
"messages": []schema.Message{
{
Role: "assistant",
StringContent: "A long time ago in a galaxy far, far away...",
},
},
"shouldUseFn": false,
},
"function_call": {
"expected": "<|im_start|>assistant\n<tool_call>\n{\"function\":\"test\"}\n</tool_call><|im_end|>",
"config": &config.BackendConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: chatML,
},
},
"functions": []functions.Function{
{
Name: "test",
Description: "test",
Parameters: nil,
},
},
"shouldUseFn": true,
"messages": []schema.Message{
{
Role: "assistant",
FunctionCall: map[string]string{"function": "test"},
},
},
},
"function_response": {
"expected": "<|im_start|>tool\n<tool_response>\nResponse from tool\n</tool_response><|im_end|>",
"config": &config.BackendConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: chatML,
},
},
"functions": []functions.Function{},
"shouldUseFn": false,
"messages": []schema.Message{
{
Role: "tool",
StringContent: "Response from tool",
},
},
},
}
var jinjaTest map[string]map[string]interface{} = map[string]map[string]interface{}{
"user": {
"expected": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
"config": &config.BackendConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: toolCallJinja,
JinjaTemplate: true,
},
},
"functions": []functions.Function{},
"shouldUseFn": false,
"messages": []schema.Message{
{
Role: "user",
StringContent: "A long time ago in a galaxy far, far away...",
},
},
},
}
var _ = Describe("Templates", func() {
Context("chat message ChatML", func() {
var evaluator *Evaluator
BeforeEach(func() {
evaluator = NewEvaluator("")
})
for key := range chatMLTestMatch {
foo := chatMLTestMatch[key]
It("renders correctly `"+key+"`", func() {
templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
Expect(templated).To(Equal(foo["expected"]), templated)
})
}
})
Context("chat message llama3", func() {
var evaluator *Evaluator
BeforeEach(func() {
evaluator = NewEvaluator("")
})
for key := range llama3TestMatch {
foo := llama3TestMatch[key]
It("renders correctly `"+key+"`", func() {
templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
Expect(templated).To(Equal(foo["expected"]), templated)
})
}
})
Context("chat message jinja", func() {
var evaluator *Evaluator
BeforeEach(func() {
evaluator = NewEvaluator("")
})
for key := range jinjaTest {
foo := jinjaTest[key]
It("renders correctly `"+key+"`", func() {
templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
Expect(templated).To(Equal(foo["expected"]), templated)
})
}
})
})

View file

@ -14,6 +14,7 @@ roles:
stopwords:
- 'Assistant:'
- '<s>'
template:
chat: |