mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-23 03:55:00 +00:00
feat: Add Diffusers (#874)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
93a4bec06b
commit
8c781a6a44
21 changed files with 741 additions and 217 deletions
|
@ -23,7 +23,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
|
|||
var err error
|
||||
|
||||
opts := []model.Option{
|
||||
model.WithLoadGRPCLLMModelOpts(grpcOpts),
|
||||
model.WithLoadGRPCLoadModelOpts(grpcOpts),
|
||||
model.WithThreads(uint32(c.Threads)),
|
||||
model.WithAssetDir(o.AssetsDestination),
|
||||
model.WithModel(modelFile),
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package backend
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
|
@ -11,16 +10,18 @@ import (
|
|||
)
|
||||
|
||||
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) {
|
||||
if c.Backend != model.StableDiffusionBackend {
|
||||
return nil, fmt.Errorf("endpoint only working with stablediffusion models")
|
||||
}
|
||||
|
||||
opts := []model.Option{
|
||||
model.WithBackendString(c.Backend),
|
||||
model.WithAssetDir(o.AssetsDestination),
|
||||
model.WithThreads(uint32(c.Threads)),
|
||||
model.WithContext(o.Context),
|
||||
model.WithModel(c.ImageGenerationAssets),
|
||||
model.WithModel(c.Model),
|
||||
model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{
|
||||
CUDA: c.Diffusers.CUDA,
|
||||
SchedulerType: c.Diffusers.SchedulerType,
|
||||
PipelineType: c.Diffusers.PipelineType,
|
||||
}),
|
||||
}
|
||||
|
||||
for k, v := range o.ExternalGRPCBackends {
|
||||
|
|
|
@ -24,7 +24,7 @@ func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c
|
|||
var err error
|
||||
|
||||
opts := []model.Option{
|
||||
model.WithLoadGRPCLLMModelOpts(grpcOpts),
|
||||
model.WithLoadGRPCLoadModelOpts(grpcOpts),
|
||||
model.WithThreads(uint32(c.Threads)), // some models uses this to allocate threads during startup
|
||||
model.WithAssetDir(o.AssetsDestination),
|
||||
model.WithModel(modelFile),
|
||||
|
|
|
@ -15,27 +15,29 @@ func gRPCModelOpts(c config.Config) *pb.ModelOptions {
|
|||
b = c.Batch
|
||||
}
|
||||
return &pb.ModelOptions{
|
||||
ContextSize: int32(c.ContextSize),
|
||||
Seed: int32(c.Seed),
|
||||
NBatch: int32(b),
|
||||
NGQA: c.NGQA,
|
||||
ModelBaseName: c.ModelBaseName,
|
||||
UseFastTokenizer: c.UseFastTokenizer,
|
||||
Device: c.Device,
|
||||
UseTriton: c.Triton,
|
||||
RMSNormEps: c.RMSNormEps,
|
||||
F16Memory: c.F16,
|
||||
MLock: c.MMlock,
|
||||
RopeFreqBase: c.RopeFreqBase,
|
||||
RopeFreqScale: c.RopeFreqScale,
|
||||
NUMA: c.NUMA,
|
||||
Embeddings: c.Embeddings,
|
||||
LowVRAM: c.LowVRAM,
|
||||
NGPULayers: int32(c.NGPULayers),
|
||||
MMap: c.MMap,
|
||||
MainGPU: c.MainGPU,
|
||||
Threads: int32(c.Threads),
|
||||
TensorSplit: c.TensorSplit,
|
||||
ContextSize: int32(c.ContextSize),
|
||||
Seed: int32(c.Seed),
|
||||
NBatch: int32(b),
|
||||
NGQA: c.NGQA,
|
||||
|
||||
RMSNormEps: c.RMSNormEps,
|
||||
F16Memory: c.F16,
|
||||
MLock: c.MMlock,
|
||||
RopeFreqBase: c.RopeFreqBase,
|
||||
RopeFreqScale: c.RopeFreqScale,
|
||||
NUMA: c.NUMA,
|
||||
Embeddings: c.Embeddings,
|
||||
LowVRAM: c.LowVRAM,
|
||||
NGPULayers: int32(c.NGPULayers),
|
||||
MMap: c.MMap,
|
||||
MainGPU: c.MainGPU,
|
||||
Threads: int32(c.Threads),
|
||||
TensorSplit: c.TensorSplit,
|
||||
// AutoGPTQ
|
||||
ModelBaseName: c.AutoGPTQ.ModelBaseName,
|
||||
Device: c.AutoGPTQ.Device,
|
||||
UseTriton: c.AutoGPTQ.Triton,
|
||||
UseFastTokenizer: c.AutoGPTQ.UseFastTokenizer,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -62,9 +64,9 @@ func gRPCPredictOpts(c config.Config, modelPath string) *pb.PredictOptions {
|
|||
RopeFreqBase: c.RopeFreqBase,
|
||||
RopeFreqScale: c.RopeFreqScale,
|
||||
NegativePrompt: c.NegativePrompt,
|
||||
Mirostat: int32(c.Mirostat),
|
||||
MirostatETA: float32(c.MirostatETA),
|
||||
MirostatTAU: float32(c.MirostatTAU),
|
||||
Mirostat: int32(c.LLMConfig.Mirostat),
|
||||
MirostatETA: float32(c.LLMConfig.MirostatETA),
|
||||
MirostatTAU: float32(c.LLMConfig.MirostatTAU),
|
||||
Debug: c.Debug,
|
||||
StopPrompts: c.StopWords,
|
||||
Repeat: int32(c.RepeatPenalty),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue