mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 10:35:01 +00:00
feat(diffusers): add DPMSolverMultistepScheduler++, DPMSolverMultistepSchedulerSDE++, guidance_scale (#903)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
0ec695f9e4
commit
37700f2d98
11 changed files with 222 additions and 196 deletions
|
@ -36,6 +36,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||
|
||||
local = False
|
||||
modelFile = request.Model
|
||||
|
||||
cfg_scale = 7
|
||||
if request.CFGScale != 0:
|
||||
cfg_scale = request.CFGScale
|
||||
|
||||
# Check if ModelFile exists
|
||||
if request.ModelFile != "":
|
||||
if os.path.exists(request.ModelFile):
|
||||
|
@ -52,30 +57,35 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||
if request.PipelineType == "StableDiffusionPipeline":
|
||||
if fromSingleFile:
|
||||
self.pipe = StableDiffusionPipeline.from_single_file(modelFile,
|
||||
torch_dtype=torchType)
|
||||
torch_dtype=torchType,
|
||||
guidance_scale=cfg_scale)
|
||||
else:
|
||||
self.pipe = StableDiffusionPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType)
|
||||
torch_dtype=torchType,
|
||||
guidance_scale=cfg_scale)
|
||||
|
||||
if request.PipelineType == "DiffusionPipeline":
|
||||
if fromSingleFile:
|
||||
self.pipe = DiffusionPipeline.from_single_file(modelFile,
|
||||
torch_dtype=torchType)
|
||||
torch_dtype=torchType,
|
||||
guidance_scale=cfg_scale)
|
||||
else:
|
||||
self.pipe = DiffusionPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType)
|
||||
torch_dtype=torchType,
|
||||
guidance_scale=cfg_scale)
|
||||
|
||||
if request.PipelineType == "StableDiffusionXLPipeline":
|
||||
if fromSingleFile:
|
||||
self.pipe = StableDiffusionXLPipeline.from_single_file(modelFile,
|
||||
torch_dtype=torchType, use_safetensors=True)
|
||||
torch_dtype=torchType, use_safetensors=True,
|
||||
guidance_scale=cfg_scale)
|
||||
else:
|
||||
self.pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
request.Model,
|
||||
torch_dtype=torchType,
|
||||
use_safetensors=True,
|
||||
# variant="fp16"
|
||||
)
|
||||
guidance_scale=cfg_scale)
|
||||
|
||||
# torch_dtype needs to be customized. float16 for GPU, float32 for CPU
|
||||
# TODO: this needs to be customized
|
||||
|
@ -83,7 +93,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||
self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe.scheduler.config)
|
||||
if request.SchedulerType == "DPMSolverMultistepScheduler":
|
||||
self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config)
|
||||
|
||||
if request.SchedulerType == "DPMSolverMultistepScheduler++":
|
||||
self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config,algorithm_type="dpmsolver++")
|
||||
if request.SchedulerType == "DPMSolverMultistepSchedulerSDE++":
|
||||
self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config, algorithm_type="sde-dpmsolver++")
|
||||
if request.CUDA:
|
||||
self.pipe.to('cuda')
|
||||
except Exception as err:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue