mirror of
https://github.com/mudler/LocalAI.git
synced 2025-06-30 06:30:43 +00:00
squash on a reasonable progress checkpoint
This commit is contained in:
parent
b8f52d67e1
commit
4d48b362f6
10 changed files with 1067 additions and 252 deletions
329
apiv2/config.go
Normal file
329
apiv2/config.go
Normal file
|
@ -0,0 +1,329 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
OpenAIRequest `yaml:"parameters"`
|
||||
Name string `yaml:"name"`
|
||||
StopWords []string `yaml:"stopwords"`
|
||||
Cutstrings []string `yaml:"cutstrings"`
|
||||
TrimSpace []string `yaml:"trimspace"`
|
||||
ContextSize int `yaml:"context_size"`
|
||||
F16 bool `yaml:"f16"`
|
||||
Threads int `yaml:"threads"`
|
||||
Debug bool `yaml:"debug"`
|
||||
Roles map[string]string `yaml:"roles"`
|
||||
Embeddings bool `yaml:"embeddings"`
|
||||
Backend string `yaml:"backend"`
|
||||
TemplateConfig TemplateConfig `yaml:"template"`
|
||||
MirostatETA float64 `yaml:"mirostat_eta"`
|
||||
MirostatTAU float64 `yaml:"mirostat_tau"`
|
||||
Mirostat int `yaml:"mirostat"`
|
||||
NGPULayers int `yaml:"gpu_layers"`
|
||||
ImageGenerationAssets string `yaml:"asset_dir"`
|
||||
PromptStrings, InputStrings []string
|
||||
InputToken [][]int
|
||||
}
|
||||
|
||||
// TemplateConfig is the "template" section of a model config.
//
// NOTE(review): the values are presumably template names for the
// completion, chat and edit endpoints — confirm against the code that
// consumes them.
type TemplateConfig struct {
	Completion string `yaml:"completion"`
	Chat       string `yaml:"chat"`
	Edit       string `yaml:"edit"`
}
|
||||
|
||||
type ConfigMerger struct {
|
||||
configs map[string]Config
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func NewConfigMerger() *ConfigMerger {
|
||||
return &ConfigMerger{
|
||||
configs: make(map[string]Config),
|
||||
}
|
||||
}
|
||||
func ReadConfigFile(file string) ([]*Config, error) {
|
||||
c := &[]*Config{}
|
||||
f, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot read config file: %w", err)
|
||||
}
|
||||
if err := yaml.Unmarshal(f, c); err != nil {
|
||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
|
||||
}
|
||||
|
||||
return *c, nil
|
||||
}
|
||||
|
||||
func ReadConfig(file string) (*Config, error) {
|
||||
c := &Config{}
|
||||
f, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot read config file: %w", err)
|
||||
}
|
||||
if err := yaml.Unmarshal(f, c); err != nil {
|
||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (cm ConfigMerger) LoadConfigFile(file string) error {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
c, err := ReadConfigFile(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot load config file: %w", err)
|
||||
}
|
||||
|
||||
for _, cc := range c {
|
||||
cm.configs[cc.Name] = *cc
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cm ConfigMerger) LoadConfig(file string) error {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
c, err := ReadConfig(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot read config file: %w", err)
|
||||
}
|
||||
|
||||
cm.configs[c.Name] = *c
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cm ConfigMerger) GetConfig(m string) (Config, bool) {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
v, exists := cm.configs[m]
|
||||
return v, exists
|
||||
}
|
||||
|
||||
func (cm ConfigMerger) ListConfigs() []string {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
var res []string
|
||||
for k := range cm.configs {
|
||||
res = append(res, k)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (cm ConfigMerger) LoadConfigs(path string) error {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
files, err := ioutil.ReadDir(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
// Skip templates, YAML and .keep files
|
||||
if !strings.Contains(file.Name(), ".yaml") {
|
||||
continue
|
||||
}
|
||||
c, err := ReadConfig(filepath.Join(path, file.Name()))
|
||||
if err == nil {
|
||||
cm.configs[c.Name] = *c
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateConfig(config *Config, input *OpenAIRequest) {
|
||||
if input.Echo {
|
||||
config.Echo = input.Echo
|
||||
}
|
||||
if input.TopK != 0 {
|
||||
config.TopK = input.TopK
|
||||
}
|
||||
if input.TopP != 0 {
|
||||
config.TopP = input.TopP
|
||||
}
|
||||
|
||||
if input.Temperature != 0 {
|
||||
config.Temperature = input.Temperature
|
||||
}
|
||||
|
||||
if input.Maxtokens != 0 {
|
||||
config.Maxtokens = input.Maxtokens
|
||||
}
|
||||
|
||||
switch stop := input.Stop.(type) {
|
||||
case string:
|
||||
if stop != "" {
|
||||
config.StopWords = append(config.StopWords, stop)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, pp := range stop {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.StopWords = append(config.StopWords, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if input.RepeatPenalty != 0 {
|
||||
config.RepeatPenalty = input.RepeatPenalty
|
||||
}
|
||||
|
||||
if input.Keep != 0 {
|
||||
config.Keep = input.Keep
|
||||
}
|
||||
|
||||
if input.Batch != 0 {
|
||||
config.Batch = input.Batch
|
||||
}
|
||||
|
||||
if input.F16 {
|
||||
config.F16 = input.F16
|
||||
}
|
||||
|
||||
if input.IgnoreEOS {
|
||||
config.IgnoreEOS = input.IgnoreEOS
|
||||
}
|
||||
|
||||
if input.Seed != 0 {
|
||||
config.Seed = input.Seed
|
||||
}
|
||||
|
||||
if input.Mirostat != 0 {
|
||||
config.Mirostat = input.Mirostat
|
||||
}
|
||||
|
||||
if input.MirostatETA != 0 {
|
||||
config.MirostatETA = input.MirostatETA
|
||||
}
|
||||
|
||||
if input.MirostatTAU != 0 {
|
||||
config.MirostatTAU = input.MirostatTAU
|
||||
}
|
||||
|
||||
switch inputs := input.Input.(type) {
|
||||
case string:
|
||||
if inputs != "" {
|
||||
config.InputStrings = append(config.InputStrings, inputs)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, pp := range inputs {
|
||||
switch i := pp.(type) {
|
||||
case string:
|
||||
config.InputStrings = append(config.InputStrings, i)
|
||||
case []interface{}:
|
||||
tokens := []int{}
|
||||
for _, ii := range i {
|
||||
tokens = append(tokens, int(ii.(float64)))
|
||||
}
|
||||
config.InputToken = append(config.InputToken, tokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch p := input.Prompt.(type) {
|
||||
case string:
|
||||
config.PromptStrings = append(config.PromptStrings, p)
|
||||
case []interface{}:
|
||||
for _, pp := range p {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.PromptStrings = append(config.PromptStrings, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) {
|
||||
input := new(OpenAIRequest)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
modelFile := input.Model
|
||||
|
||||
if c.Params("model") != "" {
|
||||
modelFile = c.Params("model")
|
||||
}
|
||||
|
||||
received, _ := json.Marshal(input)
|
||||
|
||||
log.Debug().Msgf("Request received: %s", string(received))
|
||||
|
||||
// Set model from bearer token, if available
|
||||
bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ")
|
||||
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)
|
||||
|
||||
// If no model was specified, take the first available
|
||||
if modelFile == "" && !bearerExists && randomModel {
|
||||
models, _ := loader.ListModels()
|
||||
if len(models) > 0 {
|
||||
modelFile = models[0]
|
||||
log.Debug().Msgf("No model specified, using: %s", modelFile)
|
||||
} else {
|
||||
log.Debug().Msgf("No model specified, returning error")
|
||||
return "", nil, fmt.Errorf("no model specified")
|
||||
}
|
||||
}
|
||||
|
||||
// If a model is found in bearer token takes precedence
|
||||
if bearerExists {
|
||||
log.Debug().Msgf("Using model from bearer token: %s", bearer)
|
||||
modelFile = bearer
|
||||
}
|
||||
return modelFile, input, nil
|
||||
}
|
||||
|
||||
func readConfig(modelFile string, input *OpenAIRequest, cm *ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) {
|
||||
// Load a config file if present after the model name
|
||||
modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml")
|
||||
if _, err := os.Stat(modelConfig); err == nil {
|
||||
if err := cm.LoadConfig(modelConfig); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
var config *Config
|
||||
cfg, exists := cm.GetConfig(modelFile)
|
||||
if !exists {
|
||||
config = &Config{
|
||||
OpenAIRequest: defaultRequest(modelFile),
|
||||
ContextSize: ctx,
|
||||
Threads: threads,
|
||||
F16: f16,
|
||||
Debug: debug,
|
||||
}
|
||||
} else {
|
||||
config = &cfg
|
||||
}
|
||||
|
||||
// Set the parameters for the language model prediction
|
||||
updateConfig(config, input)
|
||||
|
||||
// Don't allow 0 as setting
|
||||
if config.Threads == 0 {
|
||||
if threads != 0 {
|
||||
config.Threads = threads
|
||||
} else {
|
||||
config.Threads = 4
|
||||
}
|
||||
}
|
||||
|
||||
// Enforce debug flag if passed from CLI
|
||||
if debug {
|
||||
config.Debug = true
|
||||
}
|
||||
|
||||
return config, input, nil
|
||||
}
|
194
apiv2/localai.go
Normal file
194
apiv2/localai.go
Normal file
|
@ -0,0 +1,194 @@
|
|||
package apiv2
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type LocalAIServer struct {
|
||||
}
|
||||
|
||||
var _ ServerInterface = (*LocalAIServer)(nil)
|
||||
|
||||
// Error is the JSON body written by sendError.
type Error struct {
	// Code repeats the HTTP status code of the response.
	Code int `json:"code"`
	// Message is the human-readable description of the failure.
	Message string `json:"message"`
}

// sendError writes an error response: status code, a JSON Content-Type
// header, and an Error{Code: code, Message: message} body.
func sendError(w http.ResponseWriter, code int, message string) {
	// Headers have to be set before WriteHeader; later changes are not sent.
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)

	// The status line has already been written, so an encoding or write
	// failure can no longer be reported to the client and is ignored.
	_ = json.NewEncoder(w).Encode(Error{
		Code:    code,
		Message: message,
	})
}
|
||||
|
||||
// Message templates for endpoints that are not served. Each takes a single
// %s argument: the name of the endpoint.
const (
	// nyiErrorMessageFormatString is used for endpoints that are not yet
	// implemented. Asking users not to contact support probably won't work,
	// but it's worth a try.
	nyiErrorMessageFormatString = "%s is not yet implemented by LocalAI\nThere is no need to contact support about this error and retrying will not help.\nExpect an update at https://github.com/go-skynet/LocalAI if this changes!"

	// nyiDepreciatedErrorMessageFormatString is used for deprecated OpenAI
	// endpoints that are not implemented. The identifier keeps its original
	// spelling because other code refers to it; the message text itself
	// uses the correct word "deprecated".
	//
	// Open question: do we want or need an additional "wontfix" template
	// that is even stronger than this?
	nyiDepreciatedErrorMessageFormatString = "%s is a deprecated portion of the OpenAI API, and is not yet implemented by LocalAI\nThere is no need to contact support about this error and retrying will not help."
)
|
||||
|
||||
// CancelFineTune implements ServerInterface
|
||||
func (*LocalAIServer) CancelFineTune(w http.ResponseWriter, r *http.Request, fineTuneId string) {
|
||||
sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "Fine Tune"))
|
||||
return
|
||||
}
|
||||
|
||||
// CreateAnswer implements ServerInterface
|
||||
func (*LocalAIServer) CreateAnswer(w http.ResponseWriter, r *http.Request) {
|
||||
sendError(w, 501, fmt.Sprintf(nyiDepreciatedErrorMessageFormatString, "CreateAnswer"))
|
||||
return
|
||||
}
|
||||
|
||||
// CreateChatCompletion implements ServerInterface
|
||||
func (*LocalAIServer) CreateChatCompletion(w http.ResponseWriter, r *http.Request) {
|
||||
var chatRequest CreateChatCompletionRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&chatRequest); err != nil {
|
||||
sendError(w, http.StatusBadRequest, "Invalid CreateChatCompletionRequest")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CreateClassification implements ServerInterface
|
||||
func (*LocalAIServer) CreateClassification(w http.ResponseWriter, r *http.Request) {
|
||||
sendError(w, 501, fmt.Sprintf(nyiDepreciatedErrorMessageFormatString, "CreateClassification"))
|
||||
return
|
||||
}
|
||||
|
||||
// CreateCompletion implements ServerInterface
|
||||
func (*LocalAIServer) CreateCompletion(w http.ResponseWriter, r *http.Request) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// CreateEdit implements ServerInterface
|
||||
func (*LocalAIServer) CreateEdit(w http.ResponseWriter, r *http.Request) {
|
||||
sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "CreateEdit"))
|
||||
return
|
||||
}
|
||||
|
||||
// CreateEmbedding implements ServerInterface
|
||||
func (*LocalAIServer) CreateEmbedding(w http.ResponseWriter, r *http.Request) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// CreateFile implements ServerInterface
|
||||
func (*LocalAIServer) CreateFile(w http.ResponseWriter, r *http.Request) {
|
||||
sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "Create File"))
|
||||
return
|
||||
}
|
||||
|
||||
// CreateFineTune implements ServerInterface
|
||||
func (*LocalAIServer) CreateFineTune(w http.ResponseWriter, r *http.Request) {
|
||||
sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "Create Fine Tune"))
|
||||
return
|
||||
}
|
||||
|
||||
// CreateImage implements ServerInterface
|
||||
func (*LocalAIServer) CreateImage(w http.ResponseWriter, r *http.Request) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// CreateImageEdit implements ServerInterface
|
||||
func (*LocalAIServer) CreateImageEdit(w http.ResponseWriter, r *http.Request) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// CreateImageVariation implements ServerInterface
|
||||
func (*LocalAIServer) CreateImageVariation(w http.ResponseWriter, r *http.Request) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// CreateModeration implements ServerInterface
|
||||
func (*LocalAIServer) CreateModeration(w http.ResponseWriter, r *http.Request) {
|
||||
sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "CreateModeration"))
|
||||
return
|
||||
}
|
||||
|
||||
// CreateSearch implements ServerInterface
|
||||
func (*LocalAIServer) CreateSearch(w http.ResponseWriter, r *http.Request, engineId string) {
|
||||
sendError(w, 501, fmt.Sprintf(nyiDepreciatedErrorMessageFormatString, "CreateSearch"))
|
||||
return
|
||||
}
|
||||
|
||||
// CreateTranscription implements ServerInterface
|
||||
func (*LocalAIServer) CreateTranscription(w http.ResponseWriter, r *http.Request) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// CreateTranslation implements ServerInterface
|
||||
func (*LocalAIServer) CreateTranslation(w http.ResponseWriter, r *http.Request) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// DeleteFile implements ServerInterface
|
||||
func (*LocalAIServer) DeleteFile(w http.ResponseWriter, r *http.Request, fileId string) {
|
||||
sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "DeleteFile"))
|
||||
return
|
||||
}
|
||||
|
||||
// DeleteModel implements ServerInterface
|
||||
func (*LocalAIServer) DeleteModel(w http.ResponseWriter, r *http.Request, model string) {
|
||||
sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "DeleteModel"))
|
||||
return
|
||||
}
|
||||
|
||||
// DownloadFile implements ServerInterface
|
||||
func (*LocalAIServer) DownloadFile(w http.ResponseWriter, r *http.Request, fileId string) {
|
||||
sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "DownloadFile"))
|
||||
return
|
||||
}
|
||||
|
||||
// ListEngines implements ServerInterface
|
||||
func (*LocalAIServer) ListEngines(w http.ResponseWriter, r *http.Request) {
|
||||
sendError(w, 501, fmt.Sprintf(nyiDepreciatedErrorMessageFormatString, "List Engines"))
|
||||
return
|
||||
}
|
||||
|
||||
// ListFiles implements ServerInterface
|
||||
func (*LocalAIServer) ListFiles(w http.ResponseWriter, r *http.Request) {
|
||||
sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "ListFiles"))
|
||||
return
|
||||
}
|
||||
|
||||
// ListFineTuneEvents implements ServerInterface
|
||||
func (*LocalAIServer) ListFineTuneEvents(w http.ResponseWriter, r *http.Request, fineTuneId string, params ListFineTuneEventsParams) {
|
||||
sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "List Fine Tune Events"))
|
||||
return
|
||||
}
|
||||
|
||||
// ListFineTunes implements ServerInterface
|
||||
func (*LocalAIServer) ListFineTunes(w http.ResponseWriter, r *http.Request) {
|
||||
sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "List Fine Tunes"))
|
||||
return
|
||||
}
|
||||
|
||||
// ListModels implements ServerInterface
|
||||
func (*LocalAIServer) ListModels(w http.ResponseWriter, r *http.Request) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// RetrieveEngine implements ServerInterface
|
||||
func (*LocalAIServer) RetrieveEngine(w http.ResponseWriter, r *http.Request, engineId string) {
|
||||
sendError(w, 501, fmt.Sprintf(nyiDepreciatedErrorMessageFormatString, "RetrieveEngine"))
|
||||
return
|
||||
}
|
||||
|
||||
// RetrieveFile implements ServerInterface
|
||||
func (*LocalAIServer) RetrieveFile(w http.ResponseWriter, r *http.Request, fileId string) {
|
||||
sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "RetrieveFile"))
|
||||
return
|
||||
}
|
||||
|
||||
// RetrieveFineTune implements ServerInterface
|
||||
func (*LocalAIServer) RetrieveFineTune(w http.ResponseWriter, r *http.Request, fineTuneId string) {
|
||||
sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "Retrieve Fine Tune"))
|
||||
return
|
||||
}
|
||||
|
||||
// RetrieveModel implements ServerInterface
|
||||
func (*LocalAIServer) RetrieveModel(w http.ResponseWriter, r *http.Request, model string) {
|
||||
panic("unimplemented")
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue