feat: add initial AutoGPTQ backend implementation

This commit is contained in:
Ettore Di Giacinto 2023-08-07 22:39:10 +02:00
parent 91d49cfe9f
commit a843e64fc2
37 changed files with 660 additions and 148 deletions

View file

@ -26,7 +26,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
model.WithLoadGRPCLLMModelOpts(grpcOpts),
model.WithThreads(uint32(c.Threads)),
model.WithAssetDir(o.AssetsDestination),
model.WithModelFile(modelFile),
model.WithModel(modelFile),
model.WithContext(o.Context),
}

View file

@ -20,7 +20,7 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
model.WithAssetDir(o.AssetsDestination),
model.WithThreads(uint32(c.Threads)),
model.WithContext(o.Context),
model.WithModelFile(c.ImageGenerationAssets),
model.WithModel(c.ImageGenerationAssets),
}
for k, v := range o.ExternalGRPCBackends {

View file

@ -27,7 +27,7 @@ func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c
model.WithLoadGRPCLLMModelOpts(grpcOpts),
model.WithThreads(uint32(c.Threads)), // some models uses this to allocate threads during startup
model.WithAssetDir(o.AssetsDestination),
model.WithModelFile(modelFile),
model.WithModel(modelFile),
model.WithContext(o.Context),
}

View file

@ -19,6 +19,9 @@ func gRPCModelOpts(c config.Config) *pb.ModelOptions {
Seed: int32(c.Seed),
NBatch: int32(b),
NGQA: c.NGQA,
ModelBaseName: c.ModelBaseName,
Device: c.Device,
UseTriton: c.Triton,
RMSNormEps: c.RMSNormEps,
F16Memory: c.F16,
MLock: c.MMlock,

View file

@ -15,7 +15,7 @@ import (
func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*api.Result, error) {
opts := []model.Option{
model.WithBackendString(model.WhisperBackend),
model.WithModelFile(c.Model),
model.WithModel(c.Model),
model.WithContext(o.Context),
model.WithThreads(uint32(c.Threads)),
model.WithAssetDir(o.AssetsDestination),

View file

@ -54,6 +54,11 @@ type Config struct {
RMSNormEps float32 `yaml:"rms_norm_eps"`
NGQA int32 `yaml:"ngqa"`
// AutoGPTQ
ModelBaseName string `yaml:"model_base_name"`
Device string `yaml:"device"`
Triton bool `yaml:"triton"`
}
type Functions struct {

View file

@ -2,6 +2,7 @@ package openai
import (
"context"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/pkg/grammar"
@ -106,4 +107,9 @@ type OpenAIRequest struct {
Grammar string `json:"grammar" yaml:"grammar"`
JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"`
Backend string `json:"backend" yaml:"backend"`
// AutoGPTQ
ModelBaseName string `json:"model_base_name" yaml:"model_base_name"`
}

View file

@ -71,6 +71,14 @@ func updateConfig(config *config.Config, input *OpenAIRequest) {
config.TopP = input.TopP
}
if input.Backend != "" {
config.Backend = input.Backend
}
if input.ModelBaseName != "" {
config.ModelBaseName = input.ModelBaseName
}
if input.NegativePromptScale != 0 {
config.NegativePromptScale = input.NegativePromptScale
}